You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

common_utils.cc 41 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "kernel/common_utils.h"
  17. #include <unordered_map>
  18. #include <map>
  19. #include <bitset>
  20. #include <iostream>
  21. #include <utility>
  22. #include <fstream>
  23. #include <algorithm>
  24. #include <thread>
  25. #include "nlohmann/json.hpp"
  26. #include "backend/common/session/anf_runtime_algorithm.h"
  27. #include "include/common/utils/anfalgo.h"
  28. #include "utils/file_utils.h"
  29. #include "utils/ms_utils.h"
  30. #include "ir/manager.h"
  31. #include "ir/meta_tensor.h"
  32. #include "base/core_ops.h"
  33. #include "ir/graph_utils.h"
  34. #include "utils/ms_context.h"
  35. #include "utils/trace_base.h"
  36. #include "mindspore/ccsrc/include/common/debug/common.h"
  37. namespace mindspore {
  38. namespace kernel {
  39. constexpr char kAxis[] = "axis";
  40. constexpr char kTypeInt32[] = "Int32";
  41. constexpr auto kStridedSliceMaxDims = 8;
  42. const std::unordered_map<std::string, std::string> dtype_shortdtype_map_ = {
  43. {"float16", "f16"}, {"float32", "f32"}, {"float64", "f64"}, {"int8", "i8"}, {"int16", "i16"}, {"int32", "i32"},
  44. {"int64", "i64"}, {"uint8", "u8"}, {"uint16", "u16"}, {"uint32", "u32"}, {"uint64", "u64"}, {"bool", "bool"},
  45. };
  46. const std::unordered_map<std::string, size_t> dtype_nbyte_map = {
  47. {"float16", sizeof(float) / 2}, {"float32", sizeof(float)}, {"float64", sizeof(float) * 2},
  48. {"int8", sizeof(int) / 4}, {"int16", sizeof(int) / 2}, {"int32", sizeof(int)},
  49. {"int64", sizeof(int) * 2}, {"uint8", sizeof(int) / 4}, {"uint16", sizeof(int) / 2},
  50. {"uint32", sizeof(int)}, {"uint64", sizeof(int) * 2}, {"bool", sizeof(char)},
  51. {"complex64", sizeof(float) * 2}};
// Define all patterns here for different schedule.
// This table is used in both directions: GetFusionNameByType looks up by
// enum value, GetFusionTypeByName scans it with a case-insensitive name
// comparison. UNKNOWN_FUSION_TYPE deliberately maps to the empty string.
const std::unordered_map<FusionType, std::string> fusion_type_name_maps = {
  {FusionType::BN_UPDATE_GRAD, "bn_update_grad"},
  {FusionType::BN_GRAD_REDUCE, "bn_grad_reduce"},
  {FusionType::LAYER_NORM_GRAD, "layer_norm_grad"},
  {FusionType::L2LOSS_MUL_ADDN, "l2loss_mul_addn"},
  {FusionType::ELEMWISE, "ElemWise"},
  {FusionType::PURE_BROADCAST, "PureBroadcast"},
  {FusionType::COMMREDUCE, "CommReduce"},
  {FusionType::SEGMENT, "Segment"},
  {FusionType::INPLACE, "Inplace"},
  {FusionType::MATMUL, "Matmul"},
  {FusionType::MATMUL_V2, "Matmul_v2"},
  {FusionType::GEMM, "GEMM"},
  {FusionType::CONV, "Convolution"},
  {FusionType::CONV2D_BACKPROP_INPUT, "Conv2d_backprop_input"},
  {FusionType::CONV2D_BACKPROP_FILTER, "Conv2d_backprop_filter"},
  {FusionType::CONV3D_BACKPROP_INPUT, "Conv3d_backprop_input"},
  {FusionType::CONV3D_BACKPROP_FILTER, "Conv3d_backprop_filter"},
  {FusionType::CUBE_LAYER_NORM, "cube_layer_norm"},
  {FusionType::OPAQUE, "Opaque"},
  {FusionType::BN_REDUCE, "bn_reduce"},
  {FusionType::BN_UPDATE, "bn_update"},
  {FusionType::SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, "softmax_cross_entropy_with_logits"},
  {FusionType::L2_NORMALIZE, "l2_normalize"},
  {FusionType::SOFTMAX, "softmax_pattern"},
  {FusionType::L2_LOSS, "l2_loss"},
  {FusionType::ASCEND_QUANT, "quant"},
  {FusionType::ASCEND_DEQUANT, "dequant"},
  {FusionType::ASCEND_ANTI_QUANT, "anti_quant"},
  {FusionType::STRIDED_READ, "strided_read"},
  {FusionType::STRIDED_WRITE, "strided_write"},
  {FusionType::ASCEND_DEQUANT_S16, "dequant_s16"},
  {FusionType::ASCEND_REQUANT, "requant"},
  {FusionType::ASCEND_REQUANT_S16, "requant_s16"},
  {FusionType::MAX_POOL, "MaxPool"},
  {FusionType::DEPTHWISECONV, "DepthwiseConvolution"},
  {FusionType::CONV3D, "Conv3d"},
  {FusionType::POOL2D, "Pool2d"},
  {FusionType::POOL3D, "Pool3d"},
  {FusionType::READ_SELECT, "read_select"},
  {FusionType::WRITE_SELECT, "write_select"},
  {FusionType::COSINE_EMBEDDING_LOSS, "cosine_embedding_loss"},
  {FusionType::DILATION_PATTERN, "dilation"},
  {FusionType::BROAD_CAST, "Broadcast"},
  {FusionType::BATCH_MATMUL, "BatchMatmul"},
  {FusionType::CONFUSION_TRANSPOSE, "confusiontranspose"},
  {FusionType::DROPOUT_DOMASKV3D, "DropOutDoMaskV3D"},
  {FusionType::UNKNOWN_FUSION_TYPE, ""}};
  101. std::pair<MatrixDiag::Alignment, MatrixDiag::Alignment> GetAlignments(const std::string &alignment) {
  102. auto alignment_iter = MatrixDiag::AlignmentMap.find(alignment);
  103. if (alignment_iter == MatrixDiag::AlignmentMap.end()) {
  104. MS_LOG(EXCEPTION) << "For current kernel, input alignment is invalid: " << alignment
  105. << ". please limit it to {RIGHT_LEFT, LEFT_RIGHT, RIGHT_RIGHT, LEFT_LEFT}";
  106. }
  107. return alignment_iter->second;
  108. }
  109. int CalDiagOffset(int diag_index, int max_diag_len, int inner_rows, int inner_cols,
  110. const std::pair<MatrixDiag::Alignment, MatrixDiag::Alignment> &alignment) {
  111. bool right_align_super_diagonal = (alignment.first == MatrixDiag::RIGHT);
  112. bool right_align_sub_diagonal = (alignment.second == MatrixDiag::RIGHT);
  113. const bool right_align =
  114. (diag_index >= 0 && right_align_super_diagonal) || (diag_index <= 0 && right_align_sub_diagonal);
  115. const int diag_len = std::min(inner_rows + std::min(0, diag_index), inner_cols - std::max(0, diag_index));
  116. const int offset = (right_align) ? (max_diag_len - diag_len) : 0;
  117. return offset;
  118. }
  119. std::string GetFusionNameByType(const kernel::FusionType &type) {
  120. auto iter = fusion_type_name_maps.find(type);
  121. if (iter == fusion_type_name_maps.end()) {
  122. MS_LOG(EXCEPTION) << "Illegal fusion type: " << type;
  123. }
  124. return iter->second;
  125. }
  126. FusionType GetFusionTypeByName(const std::string &name) {
  127. std::string fusion_name_upper = name;
  128. transform(fusion_name_upper.begin(), fusion_name_upper.end(), fusion_name_upper.begin(), ::toupper);
  129. auto iter =
  130. std::find_if(fusion_type_name_maps.begin(), fusion_type_name_maps.end(), [&fusion_name_upper](const auto &it) {
  131. std::string name_upper = it.second;
  132. transform(name_upper.begin(), name_upper.end(), name_upper.begin(), ::toupper);
  133. return fusion_name_upper == name_upper;
  134. });
  135. if (iter == fusion_type_name_maps.end()) {
  136. MS_LOG(EXCEPTION) << "Illegal fusion name: " << name;
  137. }
  138. return iter->first;
  139. }
// Thin wrapper: the compiler cache root is the user-defined cache path.
std::string GetCompilerCachePath() { return Common::GetUserDefineCachePath(); }
  141. void KernelMeta::Initialize() {
  142. auto config_path = GetCompilerCachePath();
  143. kernel_meta_path_ = config_path + std::string(kAkgKernelMeta);
  144. FileUtils::CreateNotExistDirs(kernel_meta_path_);
  145. initialized_ = true;
  146. }
  147. std::string KernelMeta::Search(const std::string &kernel_name) const {
  148. if (!initialized_) {
  149. return "";
  150. }
  151. auto iter = kernel_meta_map_.find(kernel_name);
  152. if (iter == kernel_meta_map_.end()) {
  153. return "";
  154. } else {
  155. return iter->second;
  156. }
  157. }
  158. bool KernelMeta::Insert(const std::string &kernel_name, const std::string &kernel_json) {
  159. if (!initialized_) {
  160. return false;
  161. }
  162. kernel_meta_map_[kernel_name] = kernel_json;
  163. return true;
  164. }
  165. bool CheckCache(const std::string &kernel_name) {
  166. // check cache.
  167. KernelMeta *bin_map = KernelMeta::GetInstance();
  168. if (bin_map == nullptr) {
  169. MS_LOG(DEBUG) << "Kernel cache is invalid, kernel_name: " << kernel_name;
  170. return false;
  171. }
  172. std::string kernel_json = bin_map->Search(kernel_name);
  173. bool ret = (!kernel_json.empty());
  174. if (ret) {
  175. MS_LOG(INFO) << "Kernel name:" << kernel_name << " has registered.";
  176. } else {
  177. MS_LOG(INFO) << "Kernel name:" << kernel_name << " will been registered.";
  178. }
  179. return ret;
  180. }
  181. KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor) {
  182. // search cache.
  183. KernelMeta *bin_map = KernelMeta::GetInstance();
  184. if (bin_map == nullptr) {
  185. MS_LOG(DEBUG) << "kernel cache is invalid, kernel_name: " << kernel_name;
  186. return nullptr;
  187. }
  188. std::string kernel_json = bin_map->Search(kernel_name);
  189. if (!kernel_json.empty()) {
  190. KernelPackPtr kernel_pack = std::make_shared<KernelPack>();
  191. // just a tmp solution.
  192. if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) {
  193. MS_LOG(ERROR) << "Read cache json and bin file failed[" << kernel_json << "].";
  194. return nullptr;
  195. } else {
  196. return kernel_pack;
  197. }
  198. } else {
  199. MS_LOG(INFO) << "The cache kernel not found[" << kernel_name << "].";
  200. return nullptr;
  201. }
  202. }
  203. KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor) {
  204. MS_LOG(INFO) << "Insert cache for kernel:" << kernel_name << ", processr:" << processor;
  205. KernelMeta *bin_map = KernelMeta::GetInstance();
  206. std::string kernel_json = bin_map->kernel_meta_path();
  207. (void)kernel_json.append(kernel_name).append(kJsonSuffix);
  208. KernelPackPtr kernel_pack = std::make_shared<KernelPack>();
  209. if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) {
  210. MS_LOG(ERROR) << "Read json and bin file failed[" << kernel_json << "].";
  211. return nullptr;
  212. }
  213. if (bin_map == nullptr) {
  214. MS_LOG(DEBUG) << "Kernel cache is invalid, kernel name :" << kernel_name;
  215. return nullptr;
  216. }
  217. if (bin_map->Insert(kernel_name, kernel_json)) {
  218. MS_LOG(INFO) << "Kernel insert cache success[" << kernel_json << "], kernel name[" << kernel_name << "].";
  219. }
  220. return kernel_pack;
  221. }
  222. TypeId DtypeToTypeId(const std::string &dtypes) {
  223. if (dtypes == "float") {
  224. return TypeId::kNumberTypeFloat32;
  225. }
  226. if (dtypes.empty()) {
  227. return TypeId::kMetaTypeNone;
  228. }
  229. return StringToTypeId(dtypes);
  230. }
  231. std::string Dtype2ShortType(const std::string &dtype) {
  232. auto iter = dtype_shortdtype_map_.find(dtype);
  233. if (iter != dtype_shortdtype_map_.end()) {
  234. return iter->second;
  235. } else {
  236. MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtype;
  237. }
  238. }
  239. size_t GetDtypeNbyte(const std::string &dtype) {
  240. auto iter = dtype_nbyte_map.find(dtype);
  241. if (iter != dtype_nbyte_map.end()) {
  242. return iter->second;
  243. } else {
  244. MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtype;
  245. }
  246. }
  247. bool SetInputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inputs, size_t real_input_num,
  248. size_t builder_idex, const std::vector<int64_t> &dyn_input_sizes,
  249. const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
  250. MS_EXCEPTION_IF_NULL(builder);
  251. std::vector<TypeId> inputs_device_type;
  252. std::vector<std::string> inputs_format;
  253. size_t dyn_input_idx = 0;
  254. size_t kernel_info_index = 0;
  255. MS_EXCEPTION_IF_NULL(inputs[0]);
  256. size_t kernel_info_cnt = inputs[0]->dtypes().size();
  257. for (const auto &input : inputs) {
  258. MS_EXCEPTION_IF_NULL(input);
  259. std::string param_type = input->param_type();
  260. std::vector<std::string> dtypes = input->dtypes();
  261. std::vector<std::string> formats = input->formats();
  262. if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) {
  263. MS_LOG(DEBUG) << "Set input kernel builder info failed, dtyps size != formats size. dtypes size: "
  264. << dtypes.size() << ", formats size : " << formats.size();
  265. return false;
  266. }
  267. if (param_type == "dynamic") {
  268. if (dyn_input_sizes.empty()) {
  269. MS_LOG(DEBUG) << "Set input kernel builder info failed, dyn_input_sizes's size is 0 when param_type is dynamic";
  270. return false;
  271. }
  272. for (int64_t t = 0; t < dyn_input_sizes[dyn_input_idx]; t++) {
  273. kernel_info_index++;
  274. auto type_id = DtypeToTypeId(dtypes[builder_idex]);
  275. inputs_device_type.push_back(type_id);
  276. inputs_format.push_back(formats[builder_idex]);
  277. }
  278. dyn_input_idx++;
  279. } else if (param_type == "required") {
  280. kernel_info_index++;
  281. auto type_id = DtypeToTypeId(dtypes[builder_idex]);
  282. inputs_device_type.push_back(type_id);
  283. inputs_format.push_back(formats[builder_idex]);
  284. } else {
  285. if (kernel_info_index < real_input_num) {
  286. MS_LOG(INFO) << "Set input kernel builder info, input type is optional, input index is :" << kernel_info_index;
  287. kernel_info_index++;
  288. auto type_id = DtypeToTypeId(dtypes[builder_idex]);
  289. inputs_device_type.push_back(type_id);
  290. inputs_format.push_back(formats[builder_idex]);
  291. }
  292. }
  293. }
  294. builder->SetInputsDeviceType(inputs_device_type);
  295. builder->SetInputsFormat(inputs_format);
  296. return true;
  297. }
  298. bool SetOutputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> &outputs, size_t builder_idex,
  299. const size_t &real_output_num,
  300. const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
  301. // not now but in the next we need to support dynamic output case
  302. MS_EXCEPTION_IF_NULL(builder);
  303. size_t output_idx = 0;
  304. std::vector<TypeId> outputs_device_type;
  305. std::vector<std::string> outputs_format;
  306. MS_EXCEPTION_IF_NULL(outputs[0]);
  307. size_t kernel_info_cnt = outputs[0]->dtypes().size();
  308. for (const auto &output : outputs) {
  309. MS_EXCEPTION_IF_NULL(output);
  310. if (output_idx >= real_output_num) {
  311. MS_LOG(DEBUG) << "real_output_num:" << real_output_num << ", output_idx:" << output_idx << " is out of limit!";
  312. continue;
  313. }
  314. size_t output_num = 0;
  315. if (output->param_type() == "dynamic") {
  316. if (outputs.size() > 1) {
  317. MS_EXCEPTION(ArgumentError) << "Dynamic output is unsupported multi output!";
  318. }
  319. output_num = real_output_num;
  320. } else if (output->param_type() == "required") {
  321. output_num = 1;
  322. } else {
  323. if (output_idx < real_output_num) {
  324. MS_LOG(DEBUG) << "Set output kernel builder info, output type is optional, output index is :" << output_idx;
  325. output_num = 1;
  326. }
  327. }
  328. for (size_t i = 0; i < output_num; i++) {
  329. std::vector<std::string> dtypes = output->dtypes();
  330. std::vector<std::string> formats = output->formats();
  331. if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) {
  332. MS_LOG(DEBUG) << "Set output kernel builder info, dtyps size != formats size.";
  333. return false;
  334. }
  335. auto type_id = DtypeToTypeId(dtypes[builder_idex]);
  336. outputs_device_type.push_back(type_id);
  337. outputs_format.push_back(formats[builder_idex]);
  338. output_idx++;
  339. }
  340. }
  341. builder->SetOutputsFormat(outputs_format);
  342. builder->SetOutputsDeviceType(outputs_device_type);
  343. return true;
  344. }
  345. void SetKernelBuildInfo(const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder, Processor processor,
  346. const std::shared_ptr<const OpInfo> &op_info_ptr) {
  347. MS_EXCEPTION_IF_NULL(builder);
  348. MS_EXCEPTION_IF_NULL(op_info_ptr);
  349. auto imply_type = op_info_ptr->imply_type();
  350. builder->SetProcessor(processor);
  351. std::string fusion_name = op_info_ptr->fusion_type();
  352. auto fusion_type = GetFusionTypeByName(fusion_name);
  353. builder->SetFusionType(fusion_type);
  354. if (imply_type == kAKG) {
  355. builder->SetKernelType(AKG_KERNEL);
  356. } else if (imply_type == kGPU) {
  357. builder->SetKernelType(GPU_KERNEL);
  358. } else if (imply_type == kCPU) {
  359. builder->SetKernelType(CPU_KERNEL);
  360. } else if (imply_type == kAICPU) {
  361. builder->SetKernelType(AICPU_KERNEL);
  362. } else {
  363. builder->SetKernelType(TBE_KERNEL);
  364. }
  365. }
// Builds one KernelBuildInfo per dtype/format column declared in the op info
// and appends them to kernel_info_list.
// - kernel_node:      the CNode whose kernel candidates are generated.
// - op_info_ptr:      op metadata (inputs/outputs, imply type, fusion type).
// - processor:        target processor for the generated build infos.
// - kernel_info_list: output list; one entry per supported combination.
// Returns false when input/output metadata is inconsistent with the node.
bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr, Processor processor,
                   std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(kernel_info_list);
  size_t real_input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
  size_t real_output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
  std::vector<std::shared_ptr<OpIOInfo>> inputs = op_info_ptr->inputs_ptr();
  std::vector<std::shared_ptr<OpIOInfo>> outputs = op_info_ptr->outputs_ptr();
  std::vector<int64_t> dyn_input_sizes;
  auto primitive = common::AnfAlgo::GetCNodePrimitive(kernel_node);
  MS_EXCEPTION_IF_NULL(primitive);
  auto op_name = common::AnfAlgo::GetCNodeName(kernel_node);
  // Dynamic-input ops carry per-parameter element counts in this attr.
  if (primitive->GetAttr("dyn_input_sizes") != nullptr) {
    dyn_input_sizes = GetValue<std::vector<int64_t>>(primitive->GetAttr("dyn_input_sizes"));
  }
  if (inputs.size() > 0) {
    if (inputs[0] == nullptr) {
      MS_LOG(EXCEPTION) << "Inputs[0] is nullptr. Op name: " << op_name;
    }
    // The first input's dtype list length defines how many candidate
    // build infos (columns) to generate.
    size_t kernel_info_cnt = inputs[0]->dtypes().size();
    for (size_t j = 0; j < kernel_info_cnt; j++) {
      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(builder);
      SetKernelBuildInfo(builder, processor, op_info_ptr);
      if (!SetInputKernelBuilderInfo(inputs, real_input_num, j, dyn_input_sizes, builder)) {
        MS_LOG(DEBUG) << "Parse kernel metadata, set inputs kernel builder info failed. Op name: " << op_name;
        return false;
      }
      if (outputs.size() > 0) {
        if (!SetOutputKernelBuilderInfo(outputs, j, real_output_num, builder)) {
          MS_LOG(DEBUG) << "Parse kernel metadata, set outputs kernel builder info failed. Op name: " << op_name;
          return false;
        }
      }
      kernel_info_list->push_back(builder->Build());
    }
  } else if (outputs.size() > 0) {
    if (outputs[0] == nullptr) {
      MS_LOG(EXCEPTION) << "Outputs[0] is nullptr. Op name: " << op_name;
    }
    // Output-only op: the column count comes from the first output instead.
    size_t kernel_info_cnt = outputs[0]->dtypes().size();
    for (size_t j = 0; j < kernel_info_cnt; j++) {
      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(builder);
      SetKernelBuildInfo(builder, processor, op_info_ptr);
      if (!SetOutputKernelBuilderInfo(outputs, j, real_output_num, builder)) {
        MS_LOG(DEBUG) << "Parse kernel metadata, set outputs kernel builder info failed. Op name: " << op_name;
        return false;
      }
      kernel_info_list->push_back(builder->Build());
    }
  } else {
    // No declared inputs or outputs: AICPU ops still get one build info.
    if (processor == AICPU) {
      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(builder);
      SetKernelBuildInfo(builder, processor, op_info_ptr);
      kernel_info_list->push_back(builder->Build());
    }
  }
  return true;
}
  427. void SaveJsonInfo(const std::string &json_name, const std::string &info, const std::string &base_path) {
  428. std::string path = base_path + json_name + kInfoSuffix;
  429. auto realpath = Common::CreatePrefixPath(path);
  430. if (!realpath.has_value()) {
  431. MS_LOG(ERROR) << "Get real path failed, path=" << path;
  432. return;
  433. }
  434. ChangeFileMode(realpath.value(), S_IWUSR);
  435. std::ofstream filewrite(realpath.value());
  436. if (!filewrite.is_open()) {
  437. MS_LOG(ERROR) << "Open file '" << realpath.value() << "' failed!";
  438. return;
  439. }
  440. filewrite << info << std::endl;
  441. filewrite.close();
  442. ChangeFileMode(realpath.value(), S_IRUSR);
  443. }
  444. Processor GetProcessor(const string &processor) {
  445. if (processor == kProcessorAiCore) return Processor::AICORE;
  446. if (processor == kProcessorAiCpu) return Processor::AICPU;
  447. if (processor == kProcessorCuda) return Processor::CUDA;
  448. MS_LOG(DEBUG) << "Unknown processor type.";
  449. return Processor::UNKNOWN;
  450. }
  451. std::string GetProcessor(const AnfNodePtr &anf_node) {
  452. MS_EXCEPTION_IF_NULL(anf_node);
  453. std::string device;
  454. switch (AnfAlgo::GetProcessor(anf_node)) {
  455. case Processor::AICORE:
  456. device = kProcessorAiCore;
  457. break;
  458. case Processor::AICPU:
  459. device = kProcessorAiCpu;
  460. break;
  461. case Processor::CUDA:
  462. device = kProcessorCuda;
  463. break;
  464. default:
  465. MS_LOG(DEBUG) << "Unknown processor type.";
  466. break;
  467. }
  468. return device;
  469. }
  470. bool IsSameShape(const std::vector<size_t> &shape_a, const std::vector<size_t> &shape_b) {
  471. if (shape_a.size() != shape_b.size()) {
  472. return false;
  473. }
  474. for (size_t i = 0; i < shape_a.size(); ++i) {
  475. if (shape_a[i] != shape_b[i]) {
  476. return false;
  477. }
  478. }
  479. return true;
  480. }
  481. std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list,
  482. const std::vector<AnfNodePtr> &input_list,
  483. const std::vector<AnfNodePtr> &output_list) {
  484. std::vector<std::pair<AnfNodePtr, size_t>> output_index;
  485. for (size_t i = 0; i < output_list.size(); ++i) {
  486. auto const &output = output_list[i];
  487. MS_EXCEPTION_IF_NULL(output);
  488. bool found = false;
  489. auto pree_node = common::AnfAlgo::VisitKernel(output, 0);
  490. auto pos = std::find(std::begin(node_list), std::end(node_list), pree_node.first);
  491. if (pos != std::end(node_list)) {
  492. output_index.push_back(pree_node);
  493. continue;
  494. }
  495. auto ret = std::find(std::begin(input_list), std::end(input_list), pree_node.first);
  496. if (ret != std::end(input_list)) {
  497. output_index.push_back(std::make_pair(pree_node.first, 0));
  498. found = true;
  499. }
  500. if (!found) {
  501. MS_EXCEPTION(ArgumentError) << "Output [" << i << "][" << output->DebugString(2) << "] of ["
  502. << output->func_graph()->ToString() << "] found no related kernel info.";
  503. }
  504. }
  505. return output_index;
  506. }
  507. void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list) {
  508. MS_EXCEPTION_IF_NULL(node_list);
  509. MS_EXCEPTION_IF_NULL(func_graph);
  510. std::vector<AnfNodePtr> node_lists = TopoSort(func_graph->get_return());
  511. for (auto const &node : node_lists) {
  512. if (!AnfUtils::IsRealKernel(node) || !node->isa<CNode>()) {
  513. continue;
  514. }
  515. auto cnode = node->cast<CNodePtr>();
  516. MS_EXCEPTION_IF_NULL(cnode);
  517. if (IsValueNode<Primitive>(cnode->input(kAnfPrimitiveIndex))) {
  518. node_list->push_back(node);
  519. }
  520. }
  521. }
  522. void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list,
  523. std::vector<AnfNodePtr> *input_list, std::vector<AnfNodePtr> *output_list) {
  524. MS_EXCEPTION_IF_NULL(func_graph);
  525. MS_EXCEPTION_IF_NULL(node_list);
  526. MS_EXCEPTION_IF_NULL(input_list);
  527. GetValidKernelNodes(func_graph, node_list);
  528. auto parameters = func_graph->parameters();
  529. input_list->insert(input_list->begin(), parameters.begin(), parameters.end());
  530. GetFuncGraphOutputNodes(func_graph, output_list);
  531. }
  532. void GetFuncGraphOutputNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *output_list) {
  533. MS_EXCEPTION_IF_NULL(func_graph);
  534. MS_EXCEPTION_IF_NULL(output_list);
  535. auto func_output = func_graph->output();
  536. MS_EXCEPTION_IF_NULL(func_output);
  537. if (func_output->isa<CNode>()) {
  538. // multi output.
  539. auto cnode = func_output->cast<CNodePtr>();
  540. MS_EXCEPTION_IF_NULL(cnode);
  541. auto input0 = cnode->input(kAnfPrimitiveIndex);
  542. MS_EXCEPTION_IF_NULL(input0);
  543. if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
  544. for (size_t input_idx = 1; input_idx < cnode->inputs().size(); ++input_idx) {
  545. auto input_node = cnode->input(input_idx);
  546. MS_EXCEPTION_IF_NULL(input_node);
  547. if (input_node->isa<CNode>() && common::AnfAlgo::GetInputTensorNum(input_node) == 0) {
  548. continue;
  549. }
  550. output_list->push_back(common::AnfAlgo::VisitKernel(input_node, 0).first);
  551. }
  552. } else {
  553. // single output.
  554. output_list->push_back(common::AnfAlgo::VisitKernel(func_output, 0).first);
  555. }
  556. } else {
  557. // single output.
  558. output_list->push_back(common::AnfAlgo::VisitKernel(func_output, 0).first);
  559. }
  560. }
  561. bool IsWeightBoundary(const AnfNodePtr &node) {
  562. if (node->isa<ValueNode>()) {
  563. return true;
  564. }
  565. if (node->isa<Parameter>() && common::AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) {
  566. return true;
  567. }
  568. return false;
  569. }
  570. std::vector<int64_t> GetReduceAttrAxis(const CNodePtr &cnode) {
  571. if (common::AnfAlgo::GetInputTensorNum(cnode) != 1 || common::AnfAlgo::GetOutputTensorNum(cnode) != 1) {
  572. MS_LOG(EXCEPTION) << "The reduce node [" << cnode->DebugString() << "] is not single input or single output."
  573. << trace::DumpSourceLines(cnode);
  574. }
  575. std::vector<int64_t> axis;
  576. auto input_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
  577. auto primitive = common::AnfAlgo::GetCNodePrimitive(cnode);
  578. MS_EXCEPTION_IF_NULL(primitive);
  579. auto axis_attr = primitive->GetAttr(kAxis);
  580. if (axis_attr == nullptr) {
  581. MS_LOG(ERROR) << "This node doesn't have axis attr. Node info [" << cnode->DebugString() << "]";
  582. return std::vector<int64_t>();
  583. }
  584. std::vector<int64_t> axis_list;
  585. if (axis_attr->isa<Int64Imm>()) {
  586. (void)axis_list.emplace_back(GetValue<int64_t>(axis_attr));
  587. } else {
  588. axis_list = GetValue<std::vector<int64_t>>(axis_attr);
  589. }
  590. for (const auto &elem : axis_list) {
  591. if (elem < 0) {
  592. (void)axis.emplace_back(input_shape.size() + elem);
  593. } else {
  594. (void)axis.emplace_back(elem);
  595. }
  596. }
  597. common::AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis), cnode);
  598. return axis;
  599. }
// Normalizes strided-slice parameters to exactly kStridedSliceMaxDims
// dimensions: pads input_shape with 1s, clamps existing begin/end values into
// each dimension's valid range, and pads missing begin/end/stride entries
// with defaults that select the whole (padded) dimension.
// Throws when begin/end/stride lengths disagree or exceed the input rank.
void FillEmptyDims(const CNodePtr &kernel_node, std::vector<int64_t> *begin, std::vector<int64_t> *end,
                   std::vector<int64_t> *stride, std::vector<size_t> *input_shape) {
  std::vector<int64_t> &_begin = *begin;
  std::vector<int64_t> &_end = *end;
  std::vector<int64_t> &_stride = *stride;
  std::vector<size_t> &_input_shape = *input_shape;
  if (_begin.size() != _end.size() || _begin.size() != _stride.size() || _begin.size() > _input_shape.size()) {
    MS_LOG(EXCEPTION) << "For '" << common::AnfAlgo::GetCNodeName(kernel_node)
                      << "', the length of 'begin', 'stride' and 'end' should be equal "
                         "and less than or equal to the dimension of 'input_x', but got the length of 'begin': "
                      << _begin.size() << ", the length of 'stride': " << _stride.size()
                      << ", the length of 'end': " << _end.size()
                      << ", the dimension of 'input_x': " << _input_shape.size();
  }
  for (size_t i = 0; i < kStridedSliceMaxDims; i++) {
    // Pad the shape with singleton dimensions up to the max rank.
    if (i >= _input_shape.size()) {
      _input_shape.push_back(1);
    }
    if (i < _begin.size()) {
      int64_t dim = SizeToLong(_input_shape[i]);
      // Wrap a negative begin once (flooring at 0), then clamp to dim - 1.
      _begin[i] = std::min(_begin[i] < 0 ? std::max(_begin[i] + dim, static_cast<int64_t>(0)) : _begin[i], dim - 1);
    } else {
      _begin.push_back(0);
    }
    if (i < _end.size()) {
      int64_t dim = SizeToLong(_input_shape[i]);
      // Wrap a negative end once, cap at dim; -1 is the floor — presumably so
      // a reverse (negative-stride) slice can end before index 0; confirm.
      _end[i] = std::max(_end[i] < 0 ? _end[i] + dim : std::min(_end[i], dim), static_cast<int64_t>(-1));
    } else {
      _end.push_back(i < _input_shape.size() ? SizeToLong(_input_shape[i]) : 1);
    }
    // Missing strides default to 1 (take every element).
    if (i >= _stride.size()) {
      _stride.push_back(1);
    }
  }
}
  635. std::vector<bool> Dec2Bin(const int64_t &mask) {
  636. auto mask_str = std::bitset<kStridedSliceMaxDims>(mask).to_string();
  637. int64_t dim_idx = 0;
  638. std::vector<bool> result(kStridedSliceMaxDims, false);
  639. for (int64_t i = mask_str.size() - 1; i >= 0; i--) {
  640. if (mask_str[i] == '1') {
  641. result[dim_idx] = true;
  642. }
  643. dim_idx++;
  644. }
  645. return result;
  646. }
  647. void ComputeBeginMask(const CNodePtr &kernel_node, std::vector<int64_t> *begin, const std::vector<int64_t> &stride,
  648. const std::vector<size_t> &input_shape) {
  649. std::vector<int64_t> &_begin = *begin;
  650. auto begin_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrBeginMask);
  651. auto begin_mask = Dec2Bin(begin_mask_int);
  652. for (size_t i = 0; i < begin_mask.size(); i++) {
  653. if (i < kStridedSliceMaxDims && begin_mask[i]) {
  654. _begin[i] = stride[i] < 0 ? SizeToLong(input_shape[i]) - 1 : 0;
  655. }
  656. }
  657. }
  658. void ComputeEndMask(const CNodePtr &kernel_node, std::vector<int64_t> *end, const std::vector<int64_t> &stride,
  659. const std::vector<size_t> &input_shape) {
  660. std::vector<int64_t> &_end = *end;
  661. auto end_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrEndMask);
  662. auto end_mask = Dec2Bin(end_mask_int);
  663. for (size_t j = 0; j < end_mask.size(); j++) {
  664. if (j < kStridedSliceMaxDims && end_mask[j]) {
  665. _end[j] = stride[j] < 0 ? -1 : SizeToLong(input_shape[j]);
  666. }
  667. }
  668. }
  669. void ComputeEllipsisMask(const CNodePtr &kernel_node, std::vector<int64_t> *begin, std::vector<int64_t> *end,
  670. std::vector<int64_t> *stride, const std::vector<size_t> &input_shape) {
  671. std::vector<int64_t> &_begin = *begin;
  672. std::vector<int64_t> &_end = *end;
  673. std::vector<int64_t> &_stride = *stride;
  674. auto ellipsis_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrEllipsisMask);
  675. auto ellipsis_mask = Dec2Bin(ellipsis_mask_int);
  676. for (size_t k = 0; k < ellipsis_mask.size(); k++) {
  677. if (k < kStridedSliceMaxDims && ellipsis_mask[k]) {
  678. _begin[k] = 0;
  679. _end[k] = SizeToLong(input_shape[k]);
  680. _stride[k] = 1;
  681. }
  682. }
  683. }
  684. void ComputNewAxisMask(const CNodePtr &kernel_node, std::vector<int64_t> *begin, std::vector<int64_t> *end,
  685. std::vector<int64_t> *stride, const std::vector<size_t> &input_shape) {
  686. std::vector<int64_t> &_begin = *begin;
  687. std::vector<int64_t> &_end = *end;
  688. std::vector<int64_t> &_stride = *stride;
  689. auto new_axis_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrNewAxisMask);
  690. auto new_axis_mask = Dec2Bin(new_axis_mask_int);
  691. for (size_t l = 0; l < new_axis_mask.size(); l++) {
  692. if (l < kStridedSliceMaxDims && new_axis_mask[l]) {
  693. _begin[l] = 0;
  694. _end[l] = SizeToLong(input_shape[l]);
  695. _stride[l] = 1;
  696. }
  697. }
  698. }
  699. void ComputShrinkAxisMask(const CNodePtr &kernel_node, const std::vector<int64_t> &begin, std::vector<int64_t> *end,
  700. std::vector<int64_t> *stride) {
  701. std::vector<int64_t> &_end = *end;
  702. std::vector<int64_t> &_stride = *stride;
  703. auto shrink_axis_mask_int = common::AnfAlgo::GetNodeAttr<int64_t>(kernel_node, kAttrShrinkAxisMask);
  704. auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_int);
  705. for (size_t m = 0; m < shrink_axis_mask.size(); m++) {
  706. if (m < kStridedSliceMaxDims && shrink_axis_mask[m]) {
  707. _end[m] = _end[m] > begin[m] ? begin[m] + 1 : begin[m] - 1;
  708. _stride[m] = _end[m] > begin[m] ? 1 : -1;
  709. }
  710. }
  711. }
  712. void ParseStrideSliceMasks(const CNodePtr &kernel_node, std::vector<int64_t> *begin, std::vector<int64_t> *end,
  713. std::vector<int64_t> *stride, const std::vector<size_t> &input_shape) {
  714. ComputeBeginMask(kernel_node, begin, *stride, input_shape);
  715. ComputeEndMask(kernel_node, end, *stride, input_shape);
  716. ComputeEllipsisMask(kernel_node, begin, end, stride, input_shape);
  717. ComputNewAxisMask(kernel_node, begin, end, stride, input_shape);
  718. ComputShrinkAxisMask(kernel_node, *begin, end, stride);
  719. }
  720. std::string GetProcessorStr(const AnfNodePtr &anf_node) {
  721. MS_EXCEPTION_IF_NULL(anf_node);
  722. std::string processor = kProcessorUnknown;
  723. auto kernel_info = dynamic_cast<device::KernelInfo *>(anf_node->kernel_info());
  724. MS_EXCEPTION_IF_NULL(kernel_info);
  725. auto build_info = kernel_info->select_kernel_build_info();
  726. // we may call this before kernel select.
  727. if (build_info == nullptr) {
  728. return processor;
  729. }
  730. switch (build_info->processor()) {
  731. case Processor::AICORE:
  732. processor = kProcessorAiCore;
  733. break;
  734. case Processor::AICPU:
  735. processor = kProcessorAiCpu;
  736. break;
  737. case Processor::CUDA:
  738. processor = kProcessorCuda;
  739. break;
  740. default:
  741. MS_LOG(ERROR) << "Unknown processor type.";
  742. break;
  743. }
  744. return processor;
  745. }
  746. Processor GetProcessorFromContext() {
  747. kernel::Processor processor = kernel::Processor::UNKNOWN;
  748. auto context_ptr = MsContext::GetInstance();
  749. MS_EXCEPTION_IF_NULL(context_ptr);
  750. auto device_info = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  751. if (device_info == kGPUDevice) {
  752. processor = kernel::Processor::CUDA;
  753. } else if (device_info == kAscendDevice) {
  754. processor = kernel::Processor::AICORE;
  755. } else if (device_info == kCPUDevice) {
  756. processor = kernel::Processor::CPU;
  757. }
  758. return processor;
  759. }
  760. std::string GetStrProcessorFromContext() {
  761. auto processor = GetProcessorFromContext();
  762. string str_processor = kernel::kProcessorUnknown;
  763. if (processor == kernel::Processor::CUDA) {
  764. str_processor = kernel::kProcessorCuda;
  765. } else if (processor == kernel::Processor::AICORE) {
  766. str_processor = kernel::kProcessorAiCore;
  767. } else if (processor == kernel::Processor::CPU) {
  768. str_processor = kernel::kProcessorCpu;
  769. }
  770. return str_processor;
  771. }
  772. float Scaling(size_t in_size, size_t out_size, bool align_corners) {
  773. return (align_corners && out_size > 1) ? (in_size - 1) / static_cast<float>(out_size - 1)
  774. : in_size / static_cast<float>(out_size);
  775. }
  776. float ScaleGrid(const int x, const float scale) { return static_cast<float>(x) * scale; }
  777. void ComputeInterpolationWeights(const size_t out_size, const size_t in_size, const float scale,
  778. CachedInterpolation *interpolation) {
  779. interpolation[out_size].lower = 0;
  780. interpolation[out_size].upper = 0;
  781. for (size_t i = 0; i <= out_size - 1; ++i) {
  782. const float in = ScaleGrid(i, scale);
  783. const float in_f = std::floor(in);
  784. interpolation[i].lower = std::max(static_cast<size_t>(in_f), static_cast<size_t>(0));
  785. interpolation[i].upper = std::min(static_cast<size_t>(std::ceil(in)), in_size - 1);
  786. interpolation[i].lerp = in - in_f;
  787. }
  788. }
  789. bool GetShapeSize(const std::vector<size_t> &shape, const TypePtr &type_ptr, int64_t *size_i) {
  790. MS_EXCEPTION_IF_NULL(type_ptr);
  791. size_t type_byte = GetTypeByte(type_ptr);
  792. if (type_byte == 0) {
  793. return false;
  794. }
  795. for (size_t j = 0; j < shape.size(); j++) {
  796. size_i[0] = LongMulWithOverflowCheck(size_i[0], static_cast<int64_t>(shape[j]));
  797. }
  798. size_i[0] = LongMulWithOverflowCheck(size_i[0], SizeToInt(type_byte));
  799. return true;
  800. }
  801. void CastShapeSizeToLong(const std::vector<size_t> &shape, std::vector<int64_t> *long_shape) {
  802. MS_EXCEPTION_IF_NULL(long_shape);
  803. (void)std::transform(shape.begin(), shape.end(), std::back_inserter(*long_shape), SizeToLong);
  804. }
  805. void CheckSliceValid(const std::vector<int64_t> &start, const std::vector<int64_t> &stop,
  806. const std::vector<int64_t> &step, const std::vector<int64_t> &input_shape) {
  807. if (start.size() != stop.size() || start.size() != step.size() || start.size() > input_shape.size()) {
  808. MS_LOG(EXCEPTION)
  809. << "TensorCopySlices requires the length of begin, stride and end must be equal and less than input dimension.";
  810. }
  811. size_t size = start.size();
  812. for (size_t i = 0; i < size; ++i) {
  813. if (stop[i] <= start[i]) {
  814. MS_LOG(EXCEPTION) << "Invalid slice: (" << start[i] << ", " << stop[i] << " ," << step[i] << ")";
  815. }
  816. // Operator need to be generalized in the future. Only support to copy continuous memory now.
  817. if (step[i] != 1) {
  818. MS_LOG(EXCEPTION) << "The element in step only support 1, but got:" << step;
  819. }
  820. }
  821. size_t slice_pos = size;
  822. for (size_t i = 0; i < size; ++i) {
  823. if (stop[i] - start[i] > 1) {
  824. slice_pos = i;
  825. break;
  826. }
  827. }
  828. for (size_t i = slice_pos + 1; i < size; ++i) {
  829. if (stop[i] - start[i] != input_shape[i]) {
  830. MS_LOG(EXCEPTION) << "Only support copy continuous memory now. For example tensor[0, 0:100] is fine, "
  831. "but tensor[0:100, 0] is not supported.";
  832. }
  833. }
  834. }
  835. size_t GetCopySize(const std::vector<int64_t> &dim_offset, const std::vector<int64_t> &start,
  836. const std::vector<int64_t> &stop) {
  837. for (size_t i = 0; i < start.size(); ++i) {
  838. if (stop[i] - start[i] != 1) {
  839. return SizetMulWithOverflowCheck(LongToSize(stop[i] - start[i]), LongToSize(dim_offset[i]));
  840. }
  841. }
  842. return LongToSize(dim_offset[start.size() - 1]);
  843. }
  844. std::vector<int64_t> CalDimOffset(const std::vector<int64_t> &input_shape) {
  845. std::vector<int64_t> dim_offset;
  846. int64_t offset = 1;
  847. for (auto iter = input_shape.rbegin(); iter != input_shape.rend(); ++iter) {
  848. dim_offset.push_back(offset);
  849. offset = offset * (*iter);
  850. }
  851. std::reverse(dim_offset.begin(), dim_offset.end());
  852. return dim_offset;
  853. }
  854. size_t CalOffset(const std::vector<int64_t> &start, const std::vector<int64_t> &stop,
  855. const std::vector<int64_t> &dim_offset) {
  856. size_t size = start.size();
  857. size_t offset = 0;
  858. for (size_t i = 0; i < size; ++i) {
  859. offset += SizetMulWithOverflowCheck(LongToSize(dim_offset[i]), LongToSize(start[i]));
  860. if (stop[i] - start[i] != 1) {
  861. break;
  862. }
  863. }
  864. return offset;
  865. }
  866. size_t UnitSizeInBytes(const mindspore::TypeId &t) {
  867. size_t bytes = 0;
  868. switch (t) {
  869. case kNumberTypeBool:
  870. case kNumberTypeInt8:
  871. case kNumberTypeUInt8:
  872. bytes = sizeof(int8_t);
  873. break;
  874. case kNumberTypeInt16:
  875. case kNumberTypeUInt16:
  876. case kNumberTypeFloat16:
  877. bytes = sizeof(int16_t);
  878. break;
  879. case kNumberTypeInt:
  880. case kNumberTypeUInt:
  881. case kNumberTypeInt32:
  882. case kNumberTypeUInt32:
  883. case kNumberTypeFloat:
  884. case kNumberTypeFloat32:
  885. bytes = sizeof(int32_t);
  886. break;
  887. case kNumberTypeUInt64:
  888. case kNumberTypeInt64:
  889. case kNumberTypeFloat64:
  890. bytes = sizeof(int64_t);
  891. break;
  892. case kNumberTypeInt4:
  893. default:
  894. MS_LOG(EXCEPTION) << "Invalid types " << t;
  895. }
  896. return bytes;
  897. }
  898. KernelAttr &KernelAttr::AddInputAttr(const TypeId &ms_type, const std::string &format) {
  899. (void)input_type_.emplace_back(ms_type, format);
  900. return *this;
  901. }
  902. KernelAttr &KernelAttr::AddOutputAttr(const TypeId &ms_type, const std::string &format) {
  903. (void)output_type_.emplace_back(ms_type, format);
  904. return *this;
  905. }
  906. KernelAttr &KernelAttr::AddAllSameAttr(const bool &all_same) {
  907. all_same_ = all_same;
  908. return *this;
  909. }
  910. KernelAttr &KernelAttr::AddOutInRef(size_t output_index, size_t input_index) {
  911. out_in_ref_map_[output_index] = input_index;
  912. return *this;
  913. }
  914. void KernelAttr::SetInputAttrList(const std::vector<DataType> &addr_list) {
  915. input_type_.assign(addr_list.begin(), addr_list.end());
  916. }
  917. std::ostream &operator<<(std::ostream &os, KernelAttr kernel_attr) {
  918. std::stringstream ss;
  919. ss << "[Kernel Attr] all same: " << kernel_attr.GetAllSame();
  920. size_t input_num = kernel_attr.GetInputSize();
  921. if (input_num > 0) {
  922. ss << ", input(";
  923. for (size_t i = 0; i < input_num; ++i) {
  924. ss << TypeIdLabel(kernel_attr.GetInputAttr(i).first);
  925. if (i != input_num - 1) {
  926. ss << ",";
  927. }
  928. }
  929. ss << ") ";
  930. }
  931. size_t output_num = kernel_attr.GetOutputSize();
  932. if (output_num > 0) {
  933. ss << ", output(";
  934. for (size_t i = 0; i < output_num; ++i) {
  935. ss << TypeIdLabel(kernel_attr.GetOutputAttr(i).first);
  936. if (i != output_num - 1) {
  937. ss << ",";
  938. }
  939. }
  940. ss << ").";
  941. }
  942. return os << ss.str();
  943. }
  944. std::pair<bool, size_t> MatchKernelAttr(const KernelAttr &kernel_attr,
  945. const std::vector<KernelAttr> &kernel_attr_list) {
  946. // kernel_attr should not be all same. If so, then return false.
  947. if (kernel_attr.GetAllSame()) {
  948. return std::make_pair(false, 0);
  949. }
  950. auto input_num = kernel_attr.GetInputSize();
  951. auto output_num = kernel_attr.GetOutputSize();
  952. for (size_t index = 0; index < kernel_attr_list.size(); ++index) {
  953. const auto &cur_kernel_attr = kernel_attr_list[index];
  954. auto cur_input_num = cur_kernel_attr.GetInputSize();
  955. auto cur_output_num = cur_kernel_attr.GetOutputSize();
  956. if (!cur_kernel_attr.GetAllSame() && (input_num != cur_input_num || output_num != cur_output_num)) {
  957. continue;
  958. }
  959. bool mis_match = false;
  960. for (size_t i = 0; i < input_num; ++i) {
  961. auto dtype = cur_kernel_attr.GetInputAttr(cur_kernel_attr.GetAllSame() ? 0 : i).first;
  962. auto format = cur_kernel_attr.GetInputAttr(cur_kernel_attr.GetAllSame() ? 0 : i).second;
  963. if (kernel_attr.GetInputAttr(i).first != dtype || kernel_attr.GetInputAttr(i).second != format) {
  964. mis_match = true;
  965. break;
  966. }
  967. }
  968. if (mis_match) {
  969. continue;
  970. }
  971. for (size_t i = 0; i < output_num; ++i) {
  972. auto dtype = cur_kernel_attr.GetOutputAttr(cur_kernel_attr.GetAllSame() ? 0 : i).first;
  973. auto format = cur_kernel_attr.GetOutputAttr(cur_kernel_attr.GetAllSame() ? 0 : i).second;
  974. if (kernel_attr.GetOutputAttr(i).first != dtype || kernel_attr.GetOutputAttr(i).second != format) {
  975. mis_match = true;
  976. break;
  977. }
  978. }
  979. if (!mis_match) {
  980. return std::make_pair(true, index);
  981. }
  982. }
  983. return std::make_pair(false, 0);
  984. }
  985. KernelAttr GetKernelAttrFromBuildInfo(const KernelBuildInfoPtr &build_info) {
  986. MS_EXCEPTION_IF_NULL(build_info);
  987. KernelAttr kernel_attr;
  988. for (size_t i = 0; i < build_info->GetInputNum(); i++) {
  989. (void)kernel_attr.AddInputAttr(build_info->GetInputDeviceType(i), build_info->GetInputFormat(i));
  990. }
  991. for (size_t j = 0; j < build_info->GetOutputNum(); j++) {
  992. (void)kernel_attr.AddOutputAttr(build_info->GetOutputDeviceType(j), build_info->GetOutputFormat(j));
  993. }
  994. return kernel_attr;
  995. }
  996. KernelAttr GetKernelAttrFromNode(const AnfNodePtr &kernel_node) {
  997. auto build_info = AnfAlgo::GetSelectKernelBuildInfo(kernel_node);
  998. return GetKernelAttrFromBuildInfo(build_info);
  999. }
  1000. } // namespace kernel
  1001. } // namespace mindspore