You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

common_utils.cc 40 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "backend/kernel_compiler/common_utils.h"
  17. #include <unordered_map>
  18. #include <map>
  19. #include <iostream>
  20. #include <utility>
  21. #include <fstream>
  22. #include <algorithm>
  23. #include <thread>
  24. #include "nlohmann/json.hpp"
  25. #include "backend/session/anf_runtime_algorithm.h"
  26. #include "utils/ms_utils.h"
  27. #include "ir/manager.h"
  28. #include "ir/meta_tensor.h"
  29. #include "base/core_ops.h"
  30. #include "ir/graph_utils.h"
  31. #include "utils/ms_context.h"
  32. #include "mindspore/ccsrc/debug/common.h"
  33. namespace mindspore {
  34. namespace kernel {
  35. constexpr char kAxis[] = "axis";
  36. constexpr char kTypeInt32[] = "Int32";
  37. const std::unordered_map<std::string, TypeId> type_id_maps = {{"float", TypeId::kNumberTypeFloat32},
  38. {"float16", TypeId::kNumberTypeFloat16},
  39. {"float32", TypeId::kNumberTypeFloat32},
  40. {"float64", TypeId::kNumberTypeFloat64},
  41. {"int", TypeId::kNumberTypeInt},
  42. {"int8", TypeId::kNumberTypeInt8},
  43. {"int16", TypeId::kNumberTypeInt16},
  44. {"int32", TypeId::kNumberTypeInt32},
  45. {"int64", TypeId::kNumberTypeInt64},
  46. {"uint", TypeId::kNumberTypeUInt},
  47. {"uint8", TypeId::kNumberTypeUInt8},
  48. {"uint16", TypeId::kNumberTypeUInt16},
  49. {"uint32", TypeId::kNumberTypeUInt32},
  50. {"uint64", TypeId::kNumberTypeUInt64},
  51. {"bool", TypeId::kNumberTypeBool},
  52. {"complex64", TypeId::kNumberTypeComplex64},
  53. {"complex128", TypeId::kNumberTypeComplex128}};
  54. const std::map<TypeId, std::string> type_id_str_map = {{TypeId::kNumberTypeFloat32, "float32"},
  55. {TypeId::kNumberTypeFloat16, "float16"},
  56. {TypeId::kNumberTypeFloat, "float"},
  57. {TypeId::kNumberTypeFloat64, "float64"},
  58. {TypeId::kNumberTypeInt, "int"},
  59. {TypeId::kNumberTypeInt8, "int8"},
  60. {TypeId::kNumberTypeInt16, "int16"},
  61. {TypeId::kNumberTypeInt32, "int32"},
  62. {TypeId::kNumberTypeInt64, "int64"},
  63. {TypeId::kNumberTypeUInt, "uint"},
  64. {TypeId::kNumberTypeUInt8, "uint8"},
  65. {TypeId::kNumberTypeUInt16, "uint16"},
  66. {TypeId::kNumberTypeUInt32, "uint32"},
  67. {TypeId::kNumberTypeUInt64, "uint64"},
  68. {TypeId::kNumberTypeBool, "bool"},
  69. {TypeId::kNumberTypeComplex64, "complex64"},
  70. {TypeId::kNumberTypeComplex128, "complex128"}};
  71. const std::unordered_map<std::string, std::string> dtype_shortdtype_map_ = {
  72. {"float16", "f16"}, {"float32", "f32"}, {"float64", "f64"}, {"int8", "i8"}, {"int16", "i16"}, {"int32", "i32"},
  73. {"int64", "i64"}, {"uint8", "u8"}, {"uint16", "u16"}, {"uint32", "u32"}, {"uint64", "u64"}, {"bool", "bool"},
  74. };
  75. const std::unordered_map<std::string, size_t> dtype_nbyte_map = {
  76. {"float16", sizeof(float) / 2}, {"float32", sizeof(float)}, {"float64", sizeof(float) * 2},
  77. {"int8", sizeof(int) / 4}, {"int16", sizeof(int) / 2}, {"int32", sizeof(int)},
  78. {"int64", sizeof(int) * 2}, {"uint8", sizeof(int) / 4}, {"uint16", sizeof(int) / 2},
  79. {"uint32", sizeof(int)}, {"uint64", sizeof(int) * 2}, {"bool", sizeof(char)},
  80. {"complex64", sizeof(float) * 2}};
  81. // Define all patterns here for different schedule
  82. const std::unordered_map<FusionType, std::string> fusion_type_name_maps = {
  83. {FusionType::BN_UPDATE_GRAD, "bn_update_grad"},
  84. {FusionType::BN_GRAD_REDUCE, "bn_grad_reduce"},
  85. {FusionType::LAYER_NORM_GRAD, "layer_norm_grad"},
  86. {FusionType::L2LOSS_MUL_ADDN, "l2loss_mul_addn"},
  87. {FusionType::ELEMWISE, "ElemWise"},
  88. {FusionType::PURE_BROADCAST, "PureBroadcast"},
  89. {FusionType::COMMREDUCE, "CommReduce"},
  90. {FusionType::SEGMENT, "Segment"},
  91. {FusionType::INPLACE, "Inplace"},
  92. {FusionType::MATMUL, "Matmul"},
  93. {FusionType::MATMUL_V2, "Matmul_v2"},
  94. {FusionType::GEMM, "GEMM"},
  95. {FusionType::CONV, "Convolution"},
  96. {FusionType::CONV2D_BACKPROP_INPUT, "Conv2d_backprop_input"},
  97. {FusionType::CONV2D_BACKPROP_FILTER, "Conv2d_backprop_filter"},
  98. {FusionType::CONV3D_BACKPROP_INPUT, "Conv3d_backprop_input"},
  99. {FusionType::CONV3D_BACKPROP_FILTER, "Conv3d_backprop_filter"},
  100. {FusionType::CUBE_LAYER_NORM, "cube_layer_norm"},
  101. {FusionType::OPAQUE, "Opaque"},
  102. {FusionType::BN_REDUCE, "bn_reduce"},
  103. {FusionType::BN_UPDATE, "bn_update"},
  104. {FusionType::SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, "softmax_cross_entropy_with_logits"},
  105. {FusionType::L2_NORMALIZE, "l2_normalize"},
  106. {FusionType::SOFTMAX, "softmax_pattern"},
  107. {FusionType::L2_LOSS, "l2_loss"},
  108. {FusionType::ASCEND_QUANT, "quant"},
  109. {FusionType::ASCEND_DEQUANT, "dequant"},
  110. {FusionType::ASCEND_ANTI_QUANT, "anti_quant"},
  111. {FusionType::STRIDED_READ, "strided_read"},
  112. {FusionType::STRIDED_WRITE, "strided_write"},
  113. {FusionType::ASCEND_DEQUANT_S16, "dequant_s16"},
  114. {FusionType::ASCEND_REQUANT, "requant"},
  115. {FusionType::ASCEND_REQUANT_S16, "requant_s16"},
  116. {FusionType::MAX_POOL, "MaxPool"},
  117. {FusionType::DEPTHWISECONV, "DepthwiseConvolution"},
  118. {FusionType::CONV3D, "Conv3d"},
  119. {FusionType::POOL2D, "Pool2d"},
  120. {FusionType::POOL3D, "Pool3d"},
  121. {FusionType::READ_SELECT, "read_select"},
  122. {FusionType::WRITE_SELECT, "write_select"},
  123. {FusionType::COSINE_EMBEDDING_LOSS, "cosine_embedding_loss"},
  124. {FusionType::DILATION_PATTERN, "dilation"},
  125. {FusionType::BROAD_CAST, "Broadcast"},
  126. {FusionType::BATCH_MATMUL, "BatchMatmul"},
  127. {FusionType::CONFUSION_TRANSPOSE, "confusiontranspose"},
  128. {FusionType::UNKNOWN_FUSION_TYPE, ""}};
  129. std::string GetFusionNameByType(const kernel::FusionType &type) {
  130. auto iter = fusion_type_name_maps.find(type);
  131. if (iter == fusion_type_name_maps.end()) {
  132. MS_LOG(EXCEPTION) << "Illegal fusion type: " << type;
  133. }
  134. return iter->second;
  135. }
  136. FusionType GetFusionTypeByName(const std::string &name) {
  137. std::string fusion_name_upper = name;
  138. transform(fusion_name_upper.begin(), fusion_name_upper.end(), fusion_name_upper.begin(), ::toupper);
  139. auto iter =
  140. std::find_if(fusion_type_name_maps.begin(), fusion_type_name_maps.end(), [&fusion_name_upper](const auto &it) {
  141. std::string name_upper = it.second;
  142. transform(name_upper.begin(), name_upper.end(), name_upper.begin(), ::toupper);
  143. return fusion_name_upper == name_upper;
  144. });
  145. if (iter == fusion_type_name_maps.end()) {
  146. MS_LOG(EXCEPTION) << "Illegal fusion name: " << name;
  147. }
  148. return iter->first;
  149. }
  150. void KernelMeta::Initialize() {
  151. kernel_meta_path_ = std::string(kGpuKernelMeta) + "/";
  152. #if defined(_WIN32) || defined(_WIN64)
  153. auto ret = mkdir(kernel_meta_path_.c_str());
  154. #else
  155. auto ret = mkdir(kernel_meta_path_.c_str(), S_IRWXG | S_IRWXU);
  156. #endif
  157. if (ret != 0) {
  158. MS_LOG(INFO) << "kernel dir [" << kernel_meta_path_ << "], will be created later";
  159. }
  160. initialized_ = true;
  161. }
  162. std::string KernelMeta::Search(const std::string &kernel_name) const {
  163. if (!initialized_) {
  164. return "";
  165. }
  166. auto iter = kernel_meta_map_.find(kernel_name);
  167. if (iter == kernel_meta_map_.end()) {
  168. return "";
  169. } else {
  170. return iter->second;
  171. }
  172. }
  173. bool KernelMeta::Insert(const std::string &kernel_name, const std::string &kernel_json) {
  174. if (!initialized_) {
  175. return false;
  176. }
  177. kernel_meta_map_[kernel_name] = kernel_json;
  178. return true;
  179. }
  180. bool CheckCache(const std::string &kernel_name) {
  181. // check cache.
  182. KernelMeta *bin_map = KernelMeta::GetInstance();
  183. if (bin_map == nullptr) {
  184. MS_LOG(DEBUG) << "Kernel cache is invalid, kernel_name: " << kernel_name;
  185. return false;
  186. }
  187. std::string kernel_json = bin_map->Search(kernel_name);
  188. bool ret = (!kernel_json.empty());
  189. if (ret) {
  190. MS_LOG(INFO) << "Kernel name:" << kernel_name << " has registered.";
  191. } else {
  192. MS_LOG(INFO) << "Kernel name:" << kernel_name << " will been registered.";
  193. }
  194. return ret;
  195. }
  196. KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor) {
  197. // search cache.
  198. KernelMeta *bin_map = KernelMeta::GetInstance();
  199. if (bin_map == nullptr) {
  200. MS_LOG(DEBUG) << "kernel cache is invalid, kernel_name: " << kernel_name;
  201. return nullptr;
  202. }
  203. std::string kernel_json = bin_map->Search(kernel_name);
  204. if (!kernel_json.empty()) {
  205. KernelPackPtr kernel_pack = std::make_shared<KernelPack>();
  206. // just a tmp solution.
  207. if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) {
  208. MS_LOG(ERROR) << "Read cache json and bin file failed[" << kernel_json << "].";
  209. return nullptr;
  210. } else {
  211. return kernel_pack;
  212. }
  213. } else {
  214. MS_LOG(INFO) << "The cache kernel not found[" << kernel_name << "].";
  215. return nullptr;
  216. }
  217. }
  218. KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor) {
  219. MS_LOG(INFO) << "Insert cache for kernel:" << kernel_name << ", processr:" << processor;
  220. KernelMeta *bin_map = KernelMeta::GetInstance();
  221. std::string kernel_json;
  222. if (processor == kProcessorAiCore || processor == kProcessorAiCpu) {
  223. kernel_json = kCceKernelMeta;
  224. } else {
  225. kernel_json = bin_map->kernel_meta_path();
  226. }
  227. (void)kernel_json.append(kernel_name).append(kJsonSuffix);
  228. KernelPackPtr kernel_pack = std::make_shared<KernelPack>();
  229. if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) {
  230. MS_LOG(ERROR) << "Read json and bin file failed[" << kernel_json << "].";
  231. return nullptr;
  232. }
  233. if (bin_map == nullptr) {
  234. MS_LOG(DEBUG) << "Kernel cache is invalid, kernel name :" << kernel_name;
  235. return nullptr;
  236. }
  237. if (bin_map->Insert(kernel_name, kernel_json)) {
  238. MS_LOG(INFO) << "Kernel insert cache success[" << kernel_json << "], kernel name[" << kernel_name << "].";
  239. }
  240. return kernel_pack;
  241. }
  242. TypeId DtypeToTypeId(const std::string &dtypes) {
  243. auto iter = type_id_maps.find(dtypes);
  244. if (iter != type_id_maps.end()) {
  245. return iter->second;
  246. } else {
  247. MS_EXCEPTION(ArgumentError) << "Illegal input device dtype:" << dtypes;
  248. }
  249. }
  250. std::string TypeId2String(TypeId type_id, bool unknown_as_default) {
  251. auto iter = type_id_str_map.find(type_id);
  252. if (iter == type_id_str_map.end()) {
  253. if (!unknown_as_default) {
  254. MS_EXCEPTION(ArgumentError) << "Illegal input dtype." << TypeIdLabel(type_id);
  255. }
  256. MS_LOG(INFO) << "Using default dtype: float32";
  257. return "float32";
  258. }
  259. return iter->second;
  260. }
  261. std::string Dtype2ShortType(const std::string &dtype) {
  262. auto iter = dtype_shortdtype_map_.find(dtype);
  263. if (iter != dtype_shortdtype_map_.end()) {
  264. return iter->second;
  265. } else {
  266. MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtype;
  267. }
  268. }
  269. size_t GetDtypeNbyte(const std::string &dtype) {
  270. auto iter = dtype_nbyte_map.find(dtype);
  271. if (iter != dtype_nbyte_map.end()) {
  272. return iter->second;
  273. } else {
  274. MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtype;
  275. }
  276. }
  277. bool SetInputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inputs, size_t real_input_num,
  278. size_t builder_idex, const std::vector<int64_t> &dyn_input_sizes,
  279. const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
  280. MS_EXCEPTION_IF_NULL(builder);
  281. std::vector<TypeId> inputs_device_type;
  282. std::vector<std::string> inputs_format;
  283. size_t dyn_input_idx = 0;
  284. size_t kernel_info_index = 0;
  285. MS_EXCEPTION_IF_NULL(inputs[0]);
  286. size_t kernel_info_cnt = inputs[0]->dtypes().size();
  287. for (const auto &input : inputs) {
  288. MS_EXCEPTION_IF_NULL(input);
  289. std::string param_type = input->param_type();
  290. std::vector<std::string> dtypes = input->dtypes();
  291. std::vector<std::string> formats = input->formats();
  292. if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) {
  293. MS_LOG(DEBUG) << "Set input kernel builder info failed, dtyps size != formats size. dtypes size: "
  294. << dtypes.size() << ", formats size : " << formats.size();
  295. return false;
  296. }
  297. if (param_type == "dynamic") {
  298. if (dyn_input_sizes.empty()) {
  299. MS_LOG(DEBUG) << "Set input kernel builder info failed, dyn_input_sizes's size is 0 when param_type is dynamic";
  300. return false;
  301. }
  302. for (int64_t t = 0; t < dyn_input_sizes[dyn_input_idx]; t++) {
  303. kernel_info_index++;
  304. auto type_id = DtypeToTypeId(dtypes[builder_idex]);
  305. inputs_device_type.push_back(type_id);
  306. inputs_format.push_back(formats[builder_idex]);
  307. }
  308. dyn_input_idx++;
  309. } else if (param_type == "required") {
  310. kernel_info_index++;
  311. auto type_id = DtypeToTypeId(dtypes[builder_idex]);
  312. inputs_device_type.push_back(type_id);
  313. inputs_format.push_back(formats[builder_idex]);
  314. } else {
  315. if (kernel_info_index < real_input_num) {
  316. MS_LOG(INFO) << "Set input kernel builder info, input type is optional, input index is :" << kernel_info_index;
  317. kernel_info_index++;
  318. auto type_id = DtypeToTypeId(dtypes[builder_idex]);
  319. inputs_device_type.push_back(type_id);
  320. inputs_format.push_back(formats[builder_idex]);
  321. }
  322. }
  323. }
  324. builder->SetInputsDeviceType(inputs_device_type);
  325. builder->SetInputsFormat(inputs_format);
  326. return true;
  327. }
  328. bool SetOutputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> &outputs, size_t builder_idex,
  329. const size_t &real_output_num,
  330. const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
  331. // not now but in the next we need to support dynamic output case
  332. MS_EXCEPTION_IF_NULL(builder);
  333. size_t output_idx = 0;
  334. std::vector<TypeId> outputs_device_type;
  335. std::vector<std::string> outputs_format;
  336. MS_EXCEPTION_IF_NULL(outputs[0]);
  337. size_t kernel_info_cnt = outputs[0]->dtypes().size();
  338. for (const auto &output : outputs) {
  339. MS_EXCEPTION_IF_NULL(output);
  340. if (output_idx >= real_output_num) {
  341. MS_LOG(DEBUG) << "real_output_num:" << real_output_num << ", output_idx:" << output_idx << " is out of limit!";
  342. continue;
  343. }
  344. size_t output_num = 0;
  345. if (output->param_type() == "dynamic") {
  346. if (outputs.size() > 1) {
  347. MS_EXCEPTION(ArgumentError) << "Dynamic output is unsupported multi output!";
  348. }
  349. output_num = real_output_num;
  350. } else if (output->param_type() == "required") {
  351. output_num = 1;
  352. } else {
  353. if (output_idx < real_output_num) {
  354. MS_LOG(DEBUG) << "Set output kernel builder info, output type is optional, output index is :" << output_idx;
  355. output_num = 1;
  356. }
  357. }
  358. for (size_t i = 0; i < output_num; i++) {
  359. std::vector<std::string> dtypes = output->dtypes();
  360. std::vector<std::string> formats = output->formats();
  361. if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) {
  362. MS_LOG(DEBUG) << "Set output kernel builder info, dtyps size != formats size.";
  363. return false;
  364. }
  365. auto type_id = DtypeToTypeId(dtypes[builder_idex]);
  366. outputs_device_type.push_back(type_id);
  367. outputs_format.push_back(formats[builder_idex]);
  368. output_idx++;
  369. }
  370. }
  371. builder->SetOutputsFormat(outputs_format);
  372. builder->SetOutputsDeviceType(outputs_device_type);
  373. return true;
  374. }
  375. void SetKernelBuildInfo(const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder, Processor processor,
  376. const std::shared_ptr<const OpInfo> &op_info_ptr) {
  377. MS_EXCEPTION_IF_NULL(builder);
  378. MS_EXCEPTION_IF_NULL(op_info_ptr);
  379. auto imply_type = op_info_ptr->imply_type();
  380. builder->SetProcessor(processor);
  381. std::string fusion_name = op_info_ptr->fusion_type();
  382. auto fusion_type = GetFusionTypeByName(fusion_name);
  383. builder->SetFusionType(fusion_type);
  384. if (imply_type == kAKG) {
  385. builder->SetKernelType(AKG_KERNEL);
  386. } else if (imply_type == kAICPU) {
  387. builder->SetKernelType(AICPU_KERNEL);
  388. } else {
  389. builder->SetKernelType(TBE_KERNEL);
  390. }
  391. }
  392. bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr, Processor processor,
  393. std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list) {
  394. MS_EXCEPTION_IF_NULL(kernel_node);
  395. MS_EXCEPTION_IF_NULL(kernel_info_list);
  396. size_t real_input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  397. size_t real_output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  398. std::vector<std::shared_ptr<OpIOInfo>> inputs = op_info_ptr->inputs_ptr();
  399. std::vector<std::shared_ptr<OpIOInfo>> outputs = op_info_ptr->outputs_ptr();
  400. std::vector<int64_t> dyn_input_sizes;
  401. auto primitive = AnfAlgo::GetCNodePrimitive(kernel_node);
  402. MS_EXCEPTION_IF_NULL(primitive);
  403. auto op_name = AnfAlgo::GetCNodeName(kernel_node);
  404. if (primitive->GetAttr("dyn_input_sizes") != nullptr) {
  405. dyn_input_sizes = GetValue<std::vector<int64_t>>(primitive->GetAttr("dyn_input_sizes"));
  406. }
  407. if (inputs.size() > 0) {
  408. if (inputs[0] == nullptr) {
  409. MS_LOG(EXCEPTION) << "Inputs[0] is nullptr. Op name: " << op_name;
  410. }
  411. size_t kernel_info_cnt = inputs[0]->dtypes().size();
  412. for (size_t j = 0; j < kernel_info_cnt; j++) {
  413. auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
  414. MS_EXCEPTION_IF_NULL(builder);
  415. SetKernelBuildInfo(builder, processor, op_info_ptr);
  416. if (!SetInputKernelBuilderInfo(inputs, real_input_num, j, dyn_input_sizes, builder)) {
  417. MS_LOG(DEBUG) << "Parse kernel metadata, set inputs kernel builder info failed. Op name: " << op_name;
  418. return false;
  419. }
  420. if (outputs.size() > 0) {
  421. if (!SetOutputKernelBuilderInfo(outputs, j, real_output_num, builder)) {
  422. MS_LOG(DEBUG) << "Parse kernel metadata, set outputs kernel builder info failed. Op name: " << op_name;
  423. return false;
  424. }
  425. }
  426. kernel_info_list->push_back(builder->Build());
  427. }
  428. } else if (outputs.size() > 0) {
  429. if (outputs[0] == nullptr) {
  430. MS_LOG(EXCEPTION) << "Outputs[0] is nullptr. Op name: " << op_name;
  431. }
  432. size_t kernel_info_cnt = outputs[0]->dtypes().size();
  433. for (size_t j = 0; j < kernel_info_cnt; j++) {
  434. auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
  435. MS_EXCEPTION_IF_NULL(builder);
  436. SetKernelBuildInfo(builder, processor, op_info_ptr);
  437. if (!SetOutputKernelBuilderInfo(outputs, j, real_output_num, builder)) {
  438. MS_LOG(DEBUG) << "Parse kernel metadata, set outputs kernel builder info failed. Op name: " << op_name;
  439. return false;
  440. }
  441. kernel_info_list->push_back(builder->Build());
  442. }
  443. } else {
  444. if (processor == AICPU) {
  445. auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
  446. MS_EXCEPTION_IF_NULL(builder);
  447. SetKernelBuildInfo(builder, processor, op_info_ptr);
  448. kernel_info_list->push_back(builder->Build());
  449. }
  450. }
  451. return true;
  452. }
  453. void SaveJsonInfo(const std::string &json_name, const std::string &info, const std::string &base_path) {
  454. std::string path = base_path + json_name + kInfoSuffix;
  455. auto realpath = Common::CreatePrefixPath(path);
  456. if (!realpath.has_value()) {
  457. MS_LOG(ERROR) << "Get real path failed, path=" << path;
  458. return;
  459. }
  460. ChangeFileMode(realpath.value(), S_IWUSR);
  461. std::ofstream filewrite(realpath.value());
  462. if (!filewrite.is_open()) {
  463. MS_LOG(ERROR) << "Open file '" << realpath.value() << "' failed!";
  464. return;
  465. }
  466. filewrite << info << std::endl;
  467. filewrite.close();
  468. ChangeFileMode(realpath.value(), S_IRUSR);
  469. }
  470. Processor GetProcessor(const string &processor) {
  471. if (processor == kProcessorAiCore) return Processor::AICORE;
  472. if (processor == kProcessorAiCpu) return Processor::AICPU;
  473. if (processor == kProcessorCuda) return Processor::CUDA;
  474. MS_LOG(DEBUG) << "Unknown processor type.";
  475. return Processor::UNKNOWN;
  476. }
  477. std::string GetProcessor(const AnfNodePtr &anf_node) {
  478. MS_EXCEPTION_IF_NULL(anf_node);
  479. std::string device;
  480. switch (AnfAlgo::GetProcessor(anf_node)) {
  481. case Processor::AICORE:
  482. device = kProcessorAiCore;
  483. break;
  484. case Processor::AICPU:
  485. device = kProcessorAiCpu;
  486. break;
  487. case Processor::CUDA:
  488. device = kProcessorCuda;
  489. break;
  490. default:
  491. MS_LOG(DEBUG) << "Unknown processor type.";
  492. break;
  493. }
  494. return device;
  495. }
  496. bool IsSameShape(const std::vector<size_t> &shape_a, const std::vector<size_t> &shape_b) {
  497. if (shape_a.size() != shape_b.size()) {
  498. return false;
  499. }
  500. for (size_t i = 0; i < shape_a.size(); ++i) {
  501. if (shape_a[i] != shape_b[i]) {
  502. return false;
  503. }
  504. }
  505. return true;
  506. }
  507. int Sign(float x) {
  508. if (x > 0) {
  509. return 1;
  510. }
  511. if (x < 0) {
  512. return -1;
  513. }
  514. return 0;
  515. }
  516. std::pair<AnfNodePtr, size_t> GetKernelInput(const AnfNodePtr &anf_node, size_t index) {
  517. MS_EXCEPTION_IF_NULL(anf_node);
  518. if (index >= AnfAlgo::GetInputTensorNum(anf_node)) {
  519. MS_EXCEPTION(ArgumentError) << "Index is out of the size of anf_node inputs. Node info : ["
  520. << anf_node->DebugString() << "]";
  521. }
  522. auto cnode = anf_node->cast<CNodePtr>();
  523. if (cnode == nullptr) {
  524. return AnfAlgo::VisitKernel(anf_node, 0);
  525. } else {
  526. return AnfAlgo::VisitKernel(anf_node->cast<CNodePtr>()->input(index + 1), 0);
  527. }
  528. }
  529. std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> GetInputIndex(const std::vector<AnfNodePtr> &node_list,
  530. const std::vector<AnfNodePtr> &input_list) {
  531. std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> input_index;
  532. for (size_t i = 0; i < input_list.size(); ++i) {
  533. auto const &input = input_list[i];
  534. MS_EXCEPTION_IF_NULL(input);
  535. bool found = false;
  536. auto mng = input->func_graph()->manager();
  537. MS_EXCEPTION_IF_NULL(mng);
  538. const NodeUsersMap &users = mng->node_users();
  539. auto input_users = users.find(input);
  540. if (input_users == users.end() || input_users->second.empty()) {
  541. MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of ["
  542. << input->func_graph()->ToString() << "] has no users.";
  543. }
  544. for (auto const &input_user : input_users->second) {
  545. for (auto const &anf_node : node_list) {
  546. if (anf_node != input_user.first) {
  547. continue;
  548. }
  549. std::vector<int64_t> dyn_input_sizes;
  550. auto prim = AnfAlgo::GetCNodePrimitive(anf_node);
  551. MS_EXCEPTION_IF_NULL(prim);
  552. if (prim->GetAttr(kAttrDynInputSizes) != nullptr) {
  553. dyn_input_sizes = GetValue<const std::vector<int64_t>>(prim->GetAttr(kAttrDynInputSizes));
  554. }
  555. if (dyn_input_sizes.empty()) {
  556. input_index.push_back(std::make_pair(anf_node, std::make_pair(IntToSize(input_user.second - 1), 0)));
  557. found = true;
  558. break;
  559. } else {
  560. int used_as_idx = input_user.second - 1;
  561. int accum_idx = 0;
  562. size_t dyn_i = 0;
  563. for (; dyn_i < dyn_input_sizes.size(); ++dyn_i) {
  564. accum_idx += LongToInt(dyn_input_sizes[dyn_i]);
  565. if (used_as_idx < accum_idx) {
  566. input_index.push_back(std::make_pair(
  567. anf_node, std::make_pair(dyn_i, IntToSize(used_as_idx - (accum_idx - dyn_input_sizes[dyn_i])))));
  568. break;
  569. }
  570. }
  571. if (dyn_i != dyn_input_sizes.size()) {
  572. found = true;
  573. break;
  574. }
  575. }
  576. }
  577. if (found) {
  578. break;
  579. }
  580. }
  581. if (!found) {
  582. MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of ["
  583. << input->func_graph()->ToString() << "] found no related kernel info.";
  584. }
  585. }
  586. return input_index;
  587. }
  588. std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list,
  589. const std::vector<AnfNodePtr> &input_list,
  590. const std::vector<AnfNodePtr> &output_list) {
  591. std::vector<std::pair<AnfNodePtr, size_t>> output_index;
  592. for (size_t i = 0; i < output_list.size(); ++i) {
  593. auto const &output = output_list[i];
  594. MS_EXCEPTION_IF_NULL(output);
  595. bool found = false;
  596. auto pree_node = AnfAlgo::VisitKernel(output, 0);
  597. auto pos = std::find(std::begin(node_list), std::end(node_list), pree_node.first);
  598. if (pos != std::end(node_list)) {
  599. output_index.push_back(pree_node);
  600. continue;
  601. }
  602. auto ret = std::find(std::begin(input_list), std::end(input_list), pree_node.first);
  603. if (ret != std::end(input_list)) {
  604. output_index.push_back(std::make_pair(pree_node.first, 0));
  605. found = true;
  606. }
  607. if (!found) {
  608. MS_EXCEPTION(ArgumentError) << "Output [" << i << "][" << output->DebugString(2) << "] of ["
  609. << output->func_graph()->ToString() << "] found no related kernel info.";
  610. }
  611. }
  612. return output_index;
  613. }
  614. void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list) {
  615. MS_EXCEPTION_IF_NULL(node_list);
  616. MS_EXCEPTION_IF_NULL(func_graph);
  617. std::vector<AnfNodePtr> node_lists = TopoSort(func_graph->get_return());
  618. for (auto const &node : node_lists) {
  619. if (!AnfAlgo::IsRealKernel(node) || !node->isa<CNode>()) {
  620. continue;
  621. }
  622. auto cnode = node->cast<CNodePtr>();
  623. MS_EXCEPTION_IF_NULL(cnode);
  624. if (IsValueNode<Primitive>(cnode->input(kAnfPrimitiveIndex))) {
  625. node_list->push_back(node);
  626. }
  627. }
  628. }
  629. void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list,
  630. std::vector<AnfNodePtr> *input_list, std::vector<AnfNodePtr> *output_list) {
  631. MS_EXCEPTION_IF_NULL(func_graph);
  632. MS_EXCEPTION_IF_NULL(node_list);
  633. MS_EXCEPTION_IF_NULL(input_list);
  634. GetValidKernelNodes(func_graph, node_list);
  635. auto parameters = func_graph->parameters();
  636. input_list->insert(input_list->begin(), parameters.begin(), parameters.end());
  637. GetFuncGraphOutputNodes(func_graph, output_list);
  638. }
  639. void GetFuncGraphOutputNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *output_list) {
  640. MS_EXCEPTION_IF_NULL(func_graph);
  641. MS_EXCEPTION_IF_NULL(output_list);
  642. auto func_output = func_graph->output();
  643. MS_EXCEPTION_IF_NULL(func_output);
  644. if (func_output->isa<CNode>()) {
  645. // multi output.
  646. auto cnode = func_output->cast<CNodePtr>();
  647. MS_EXCEPTION_IF_NULL(cnode);
  648. auto input0 = cnode->input(kAnfPrimitiveIndex);
  649. MS_EXCEPTION_IF_NULL(input0);
  650. if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
  651. for (size_t input_idx = 1; input_idx < cnode->inputs().size(); ++input_idx) {
  652. auto input_node = cnode->input(input_idx);
  653. MS_EXCEPTION_IF_NULL(input_node);
  654. if (input_node->isa<CNode>() && AnfAlgo::GetInputTensorNum(input_node) == 0) {
  655. continue;
  656. }
  657. output_list->push_back(AnfAlgo::VisitKernel(input_node, 0).first);
  658. }
  659. } else {
  660. // single output.
  661. output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first);
  662. }
  663. } else {
  664. // single output.
  665. output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first);
  666. }
  667. }
// Extracts the scalar value of input `input_idx` of `anf_node` into (*node_json)["value"].
// Returns true only when that input is a constant single-element tensor whose value was
// written to the json; returns false for non-constant inputs, null tensor values, or
// tensors with more than one element. Throws on an out-of-range index or an unsupported
// tensor dtype.
bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *const node_json) {
  MS_EXCEPTION_IF_NULL(anf_node);
  MS_EXCEPTION_IF_NULL(node_json);
  auto cnode = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // CNode input 0 is the primitive, so real input `input_idx` lives at slot input_idx + 1.
  if (input_idx + 1 >= cnode->size()) {
    MS_EXCEPTION(ArgumentError) << "input_idx [" << input_idx << "] is out of index of inputs of ["
                                << cnode->inputs().size() << "][" << cnode->DebugString() << "]";
  }
  auto input_node = cnode->input(input_idx + 1);
  // Only constant tensor inputs carry a compile-time value.
  if (!IsValueNode<tensor::Tensor>(input_node)) {
    return false;
  }
  auto tensor = GetValueNode<tensor::TensorPtr>(input_node);
  if (tensor == nullptr) {
    MS_LOG(DEBUG) << "Value of input node is nullptr, op: [" << input_node->DebugString() << "]";
    return false;
  }
  auto type_id = tensor->data_type();
  auto *data = tensor->data_c();
  MS_EXCEPTION_IF_NULL(data);
  // Only single-element tensors are treated as scalar constants.
  if (tensor->DataSize() > 1) {
    // not const tensor.
    MS_LOG(WARNING) << "Not take value of tensor whose datasize greater than 1, [" << input_node->DebugString(2) << "]";
    return false;
  }
  // Dispatch on the tensor dtype and reinterpret the raw buffer accordingly.
  // (type_id values are runtime constants, hence the if/else chain instead of a switch.)
  if (type_id == kFloat64->type_id()) {
    (*node_json)["value"] = static_cast<double *>(data)[0];
  } else if (type_id == kFloat32->type_id()) {
    (*node_json)["value"] = static_cast<float *>(data)[0];
  } else if (type_id == kFloat16->type_id()) {
    // float16 has no native json representation; widen to float first.
    float16 *val = static_cast<float16 *>(data);
    (*node_json)["value"] = static_cast<float>(val[0]);
  } else if (type_id == kUInt64->type_id()) {
    (*node_json)["value"] = static_cast<uint64_t *>(data)[0];
  } else if (type_id == kUInt32->type_id()) {
    (*node_json)["value"] = static_cast<uint32_t *>(data)[0];
  } else if (type_id == kUInt16->type_id()) {
    (*node_json)["value"] = static_cast<uint16_t *>(data)[0];
  } else if (type_id == kUInt8->type_id()) {
    (*node_json)["value"] = static_cast<uint8_t *>(data)[0];
  } else if (type_id == kInt64->type_id()) {
    (*node_json)["value"] = static_cast<int64_t *>(data)[0];
  } else if (type_id == kInt32->type_id()) {
    (*node_json)["value"] = static_cast<int32_t *>(data)[0];
  } else if (type_id == kInt16->type_id()) {
    (*node_json)["value"] = static_cast<int16_t *>(data)[0];
  } else if (type_id == kInt8->type_id()) {
    (*node_json)["value"] = static_cast<int8_t *>(data)[0];
  } else if (type_id == kBool->type_id()) {
    (*node_json)["value"] = static_cast<bool *>(data)[0];
  } else {
    MS_LOG(EXCEPTION) << "Unknown value type of tensor[" << cnode->DebugString() << "]";
  }
  return true;
}
  724. bool IsWeightBoundary(const AnfNodePtr &node) {
  725. if (node->isa<ValueNode>()) {
  726. return true;
  727. }
  728. if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) {
  729. return true;
  730. }
  731. return false;
  732. }
  733. std::vector<int64_t> GetReduceAttrAxis(const CNodePtr &cnode) {
  734. if (AnfAlgo::GetInputTensorNum(cnode) != 1 || AnfAlgo::GetOutputTensorNum(cnode) != 1) {
  735. MS_LOG(EXCEPTION) << "The reduce node [" << cnode->DebugString() << "] is not single input or single output.";
  736. }
  737. std::vector<int64_t> axis;
  738. auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
  739. auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
  740. MS_EXCEPTION_IF_NULL(primitive);
  741. auto axis_attr = primitive->GetAttr(kAxis);
  742. if (axis_attr == nullptr) {
  743. MS_LOG(ERROR) << "This node doesn't have axis attr. Node info [" << cnode->DebugString() << "]";
  744. return std::vector<int64_t>();
  745. }
  746. std::vector<int64_t> axis_list;
  747. if (axis_attr->isa<Int64Imm>()) {
  748. (void)axis_list.emplace_back(GetValue<int64_t>(axis_attr));
  749. } else {
  750. axis_list = GetValue<std::vector<int64_t>>(axis_attr);
  751. }
  752. for (const auto &elem : axis_list) {
  753. if (elem < 0) {
  754. (void)axis.emplace_back(input_shape.size() + elem);
  755. } else {
  756. (void)axis.emplace_back(elem);
  757. }
  758. }
  759. AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis), cnode);
  760. return axis;
  761. }
  762. std::string GetProcessorStr(const AnfNodePtr &anf_node) {
  763. MS_EXCEPTION_IF_NULL(anf_node);
  764. std::string processor = kProcessorUnknown;
  765. auto kernel_info = dynamic_cast<device::KernelInfo *>(anf_node->kernel_info());
  766. MS_EXCEPTION_IF_NULL(kernel_info);
  767. auto build_info = kernel_info->select_kernel_build_info();
  768. // we may call this before kernel select.
  769. if (build_info == nullptr) {
  770. return processor;
  771. }
  772. switch (build_info->processor()) {
  773. case Processor::AICORE:
  774. processor = kProcessorAiCore;
  775. break;
  776. case Processor::AICPU:
  777. processor = kProcessorAiCpu;
  778. break;
  779. case Processor::CUDA:
  780. processor = kProcessorCuda;
  781. break;
  782. default:
  783. MS_LOG(ERROR) << "Unknown processor type.";
  784. break;
  785. }
  786. return processor;
  787. }
  788. Processor GetProcessorFromContext() {
  789. kernel::Processor processor = kernel::Processor::UNKNOWN;
  790. auto context_ptr = MsContext::GetInstance();
  791. MS_EXCEPTION_IF_NULL(context_ptr);
  792. auto device_info = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  793. if (device_info == kGPUDevice) {
  794. processor = kernel::Processor::CUDA;
  795. } else if (device_info == kAscendDevice) {
  796. processor = kernel::Processor::AICORE;
  797. }
  798. return processor;
  799. }
  800. std::string GetStrProcessorFromContext() {
  801. auto processor = GetProcessorFromContext();
  802. string str_processor = kernel::kProcessorUnknown;
  803. if (processor == kernel::Processor::CUDA) {
  804. str_processor = kernel::kProcessorCuda;
  805. } else if (processor == kernel::Processor::AICORE) {
  806. str_processor = kernel::kProcessorAiCore;
  807. }
  808. return str_processor;
  809. }
  810. float Scaling(size_t in_size, size_t out_size, bool align_corners) {
  811. return (align_corners && out_size > 1) ? (in_size - 1) / static_cast<float>(out_size - 1)
  812. : in_size / static_cast<float>(out_size);
  813. }
  814. float ScaleGrid(const int x, const float scale) { return static_cast<float>(x) * scale; }
  815. void ComputeInterpolationWeights(const size_t out_size, const size_t in_size, const float scale,
  816. CachedInterpolation *interpolation) {
  817. interpolation[out_size].lower = 0;
  818. interpolation[out_size].upper = 0;
  819. for (size_t i = 0; i <= out_size - 1; ++i) {
  820. const float in = ScaleGrid(i, scale);
  821. const float in_f = std::floor(in);
  822. interpolation[i].lower = std::max(static_cast<size_t>(in_f), static_cast<size_t>(0));
  823. interpolation[i].upper = std::min(static_cast<size_t>(std::ceil(in)), in_size - 1);
  824. interpolation[i].lerp = in - in_f;
  825. }
  826. }
  827. bool GetShapeSize(const std::vector<size_t> &shape, const TypePtr &type_ptr, int64_t *size_i) {
  828. MS_EXCEPTION_IF_NULL(type_ptr);
  829. size_t type_byte = GetTypeByte(type_ptr);
  830. if (type_byte == 0) {
  831. return false;
  832. }
  833. for (size_t j = 0; j < shape.size(); j++) {
  834. size_i[0] = LongMulWithOverflowCheck(size_i[0], static_cast<int>(shape[j]));
  835. }
  836. size_i[0] = LongMulWithOverflowCheck(size_i[0], SizeToInt(type_byte));
  837. return true;
  838. }
  839. void CastShapeSizeToLong(const std::vector<size_t> &shape, std::vector<int64_t> *long_shape) {
  840. MS_EXCEPTION_IF_NULL(long_shape);
  841. std::transform(shape.begin(), shape.end(), std::back_inserter(*long_shape), SizeToLong);
  842. }
  843. void CheckSliceValid(const std::vector<int64_t> &start, const std::vector<int64_t> &stop,
  844. const std::vector<int64_t> &step, const std::vector<int64_t> &input_shape) {
  845. if (start.size() != stop.size() || start.size() != step.size() || start.size() > input_shape.size()) {
  846. MS_LOG(EXCEPTION)
  847. << "TensorCopySlices requires the length of begin, stride and end must be equal and less than input dimension.";
  848. }
  849. size_t size = start.size();
  850. for (size_t i = 0; i < size; ++i) {
  851. if (stop[i] <= start[i]) {
  852. MS_LOG(EXCEPTION) << "Invalid slice: (" << start[i] << ", " << stop[i] << " ," << step[i] << ")";
  853. }
  854. // Operator need to be generalized in the future. Only support to copy continuous memory now.
  855. if (step[i] != 1) {
  856. MS_LOG(EXCEPTION) << "The element in step only support 1, but got:" << step;
  857. }
  858. }
  859. size_t slice_pos = size;
  860. for (size_t i = 0; i < size; ++i) {
  861. if (stop[i] - start[i] > 1) {
  862. slice_pos = i;
  863. break;
  864. }
  865. }
  866. for (size_t i = slice_pos + 1; i < size; ++i) {
  867. if (stop[i] - start[i] != input_shape[i]) {
  868. MS_LOG(EXCEPTION) << "Only support copy continuous memory now. For example tensor[0, 0:100] is fine, "
  869. "but tensor[0:100, 0] is not supported.";
  870. }
  871. }
  872. }
  873. size_t GetCopySize(const std::vector<int64_t> &dim_offset, const std::vector<int64_t> &start,
  874. const std::vector<int64_t> &stop) {
  875. for (size_t i = 0; i < start.size(); ++i) {
  876. if (stop[i] - start[i] != 1) {
  877. return SizetMulWithOverflowCheck(LongToSize(stop[i] - start[i]), LongToSize(dim_offset[i]));
  878. }
  879. }
  880. return LongToSize(dim_offset[start.size() - 1]);
  881. }
  882. std::vector<int64_t> CalDimOffset(const std::vector<int64_t> &input_shape) {
  883. std::vector<int64_t> dim_offset;
  884. int64_t offset = 1;
  885. for (auto iter = input_shape.rbegin(); iter != input_shape.rend(); ++iter) {
  886. dim_offset.push_back(offset);
  887. offset = offset * (*iter);
  888. }
  889. std::reverse(dim_offset.begin(), dim_offset.end());
  890. return dim_offset;
  891. }
  892. size_t CalOffset(const std::vector<int64_t> &start, const std::vector<int64_t> &stop,
  893. const std::vector<int64_t> &dim_offset) {
  894. size_t size = start.size();
  895. size_t offset = 0;
  896. for (size_t i = 0; i < size; ++i) {
  897. offset += SizetMulWithOverflowCheck(LongToSize(dim_offset[i]), start[i]);
  898. if (stop[i] - start[i] != 1) {
  899. break;
  900. }
  901. }
  902. return offset;
  903. }
  904. size_t UnitSizeInBytes(const mindspore::TypeId &t) {
  905. size_t bytes = 0;
  906. switch (t) {
  907. case kNumberTypeBool:
  908. case kNumberTypeInt8:
  909. case kNumberTypeUInt8:
  910. bytes = sizeof(int8_t);
  911. break;
  912. case kNumberTypeInt16:
  913. case kNumberTypeUInt16:
  914. case kNumberTypeFloat16:
  915. bytes = sizeof(int16_t);
  916. break;
  917. case kNumberTypeInt:
  918. case kNumberTypeUInt:
  919. case kNumberTypeInt32:
  920. case kNumberTypeUInt32:
  921. case kNumberTypeFloat:
  922. case kNumberTypeFloat32:
  923. bytes = sizeof(int32_t);
  924. break;
  925. case kNumberTypeUInt64:
  926. case kNumberTypeInt64:
  927. case kNumberTypeFloat64:
  928. bytes = sizeof(int64_t);
  929. break;
  930. default:
  931. MS_LOG(EXCEPTION) << "Invalid types " << t;
  932. break;
  933. }
  934. return bytes;
  935. }
  936. } // namespace kernel
  937. } // namespace mindspore