You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

common_utils.cc 38 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "backend/kernel_compiler/common_utils.h"
#include <cctype>
#include <cstdint>
#include <unordered_map>
#include <map>
#include <iostream>
#include <utility>
#include <fstream>
#include <algorithm>
#include <thread>
#include "nlohmann/json.hpp"
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/ms_utils.h"
#include "ir/manager.h"
#include "ir/meta_tensor.h"
#include "base/core_ops.h"
#include "ir/graph_utils.h"
#include "utils/ms_context.h"
#include "mindspore/ccsrc/debug/common.h"
  33. namespace mindspore {
  34. namespace kernel {
// Attribute name used when reading a reduction axis from a primitive.
constexpr char kAxis[] = "axis";
// Canonical display name of the 32-bit signed integer type.
constexpr char kTypeInt32[] = "Int32";
// Maps a device dtype string (as used in op info / kernel json) to the
// corresponding TypeId enum value. Note "float" is an alias of "float32".
const std::unordered_map<std::string, TypeId> type_id_maps = {
  {"float", TypeId::kNumberTypeFloat32},   {"float16", TypeId::kNumberTypeFloat16},
  {"float32", TypeId::kNumberTypeFloat32}, {"float64", TypeId::kNumberTypeFloat64},
  {"int", TypeId::kNumberTypeInt},         {"int8", TypeId::kNumberTypeInt8},
  {"int16", TypeId::kNumberTypeInt16},     {"int32", TypeId::kNumberTypeInt32},
  {"int64", TypeId::kNumberTypeInt64},     {"uint", TypeId::kNumberTypeUInt},
  {"uint8", TypeId::kNumberTypeUInt8},     {"uint16", TypeId::kNumberTypeUInt16},
  {"uint32", TypeId::kNumberTypeUInt32},   {"uint64", TypeId::kNumberTypeUInt64},
  {"bool", TypeId::kNumberTypeBool},       {"complex64", TypeId::kNumberTypeComplex64}};
// Reverse mapping of type_id_maps: TypeId enum value back to its dtype string.
// Unlike type_id_maps this is an ordered std::map keyed by the enum.
const std::map<TypeId, std::string> type_id_str_map = {
  {TypeId::kNumberTypeFloat32, "float32"}, {TypeId::kNumberTypeFloat16, "float16"},
  {TypeId::kNumberTypeFloat, "float"},     {TypeId::kNumberTypeFloat64, "float64"},
  {TypeId::kNumberTypeInt, "int"},         {TypeId::kNumberTypeInt8, "int8"},
  {TypeId::kNumberTypeInt16, "int16"},     {TypeId::kNumberTypeInt32, "int32"},
  {TypeId::kNumberTypeInt64, "int64"},     {TypeId::kNumberTypeUInt, "uint"},
  {TypeId::kNumberTypeUInt8, "uint8"},     {TypeId::kNumberTypeUInt16, "uint16"},
  {TypeId::kNumberTypeUInt32, "uint32"},   {TypeId::kNumberTypeUInt64, "uint64"},
  {TypeId::kNumberTypeBool, "bool"},       {TypeId::kNumberTypeComplex64, "complex64"}};
  55. const std::unordered_map<std::string, std::string> dtype_shortdtype_map_ = {
  56. {"float16", "f16"}, {"float32", "f32"}, {"float64", "f64"}, {"int8", "i8"}, {"int16", "i16"}, {"int32", "i32"},
  57. {"int64", "i64"}, {"uint8", "u8"}, {"uint16", "u16"}, {"uint32", "u32"}, {"uint64", "u64"}, {"bool", "bool"},
  58. };
  59. const std::unordered_map<std::string, size_t> dtype_nbyte_map = {
  60. {"float16", sizeof(float) / 2}, {"float32", sizeof(float)}, {"float64", sizeof(float) * 2},
  61. {"int8", sizeof(int) / 4}, {"int16", sizeof(int) / 2}, {"int32", sizeof(int)},
  62. {"int64", sizeof(int) * 2}, {"uint8", sizeof(int) / 4}, {"uint16", sizeof(int) / 2},
  63. {"uint32", sizeof(int)}, {"uint64", sizeof(int) * 2}, {"bool", sizeof(char)},
  64. {"complex64", sizeof(float) * 2}};
// Define all patterns here for different schedule
// Maps each FusionType enum value to its schedule-pattern name string.
// UNKNOWN_FUSION_TYPE intentionally maps to the empty string.
// Name lookups are done in both directions by GetFusionNameByType /
// GetFusionTypeByName below (the latter compares case-insensitively).
const std::unordered_map<FusionType, std::string> fusion_type_name_maps = {
  {FusionType::BN_UPDATE_GRAD, "bn_update_grad"},
  {FusionType::BN_GRAD_REDUCE, "bn_grad_reduce"},
  {FusionType::LAYER_NORM_GRAD, "layer_norm_grad"},
  {FusionType::L2LOSS_MUL_ADDN, "l2loss_mul_addn"},
  {FusionType::ELEMWISE, "ElemWise"},
  {FusionType::PURE_BROADCAST, "PureBroadcast"},
  {FusionType::COMMREDUCE, "CommReduce"},
  {FusionType::SEGMENT, "Segment"},
  {FusionType::INPLACE, "Inplace"},
  {FusionType::MATMUL, "Matmul"},
  {FusionType::MATMUL_V2, "Matmul_v2"},
  {FusionType::GEMM, "GEMM"},
  {FusionType::CONV, "Convolution"},
  {FusionType::CONV2D_BACKPROP_INPUT, "Conv2d_backprop_input"},
  {FusionType::CONV2D_BACKPROP_FILTER, "Conv2d_backprop_filter"},
  {FusionType::CONV3D_BACKPROP_INPUT, "Conv3d_backprop_input"},
  {FusionType::CONV3D_BACKPROP_FILTER, "Conv3d_backprop_filter"},
  {FusionType::CUBE_LAYER_NORM, "cube_layer_norm"},
  {FusionType::OPAQUE, "Opaque"},
  {FusionType::BN_REDUCE, "bn_reduce"},
  {FusionType::BN_UPDATE, "bn_update"},
  {FusionType::SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, "softmax_cross_entropy_with_logits"},
  {FusionType::L2_NORMALIZE, "l2_normalize"},
  {FusionType::SOFTMAX, "softmax_pattern"},
  {FusionType::L2_LOSS, "l2_loss"},
  {FusionType::ASCEND_QUANT, "quant"},
  {FusionType::ASCEND_DEQUANT, "dequant"},
  {FusionType::ASCEND_ANTI_QUANT, "anti_quant"},
  {FusionType::STRIDED_READ, "strided_read"},
  {FusionType::STRIDED_WRITE, "strided_write"},
  {FusionType::ASCEND_DEQUANT_S16, "dequant_s16"},
  {FusionType::ASCEND_REQUANT, "requant"},
  {FusionType::ASCEND_REQUANT_S16, "requant_s16"},
  {FusionType::MAX_POOL, "MaxPool"},
  {FusionType::DEPTHWISECONV, "DepthwiseConvolution"},
  {FusionType::CONV3D, "Conv3d"},
  {FusionType::POOL2D, "Pool2d"},
  {FusionType::POOL3D, "Pool3d"},
  {FusionType::READ_SELECT, "read_select"},
  {FusionType::WRITE_SELECT, "write_select"},
  {FusionType::COSINE_EMBEDDING_LOSS, "cosine_embedding_loss"},
  {FusionType::DILATION_PATTERN, "dilation"},
  {FusionType::BROAD_CAST, "Broadcast"},
  {FusionType::BATCH_MATMUL, "BatchMatmul"},
  {FusionType::CONFUSION_TRANSPOSE, "confusiontranspose"},
  {FusionType::UNKNOWN_FUSION_TYPE, ""}};
  113. std::string GetFusionNameByType(const kernel::FusionType &type) {
  114. auto iter = fusion_type_name_maps.find(type);
  115. if (iter == fusion_type_name_maps.end()) {
  116. MS_LOG(EXCEPTION) << "Illegal fusion type: " << type;
  117. }
  118. return iter->second;
  119. }
  120. FusionType GetFusionTypeByName(const std::string &name) {
  121. std::string fusion_name_upper = name;
  122. transform(fusion_name_upper.begin(), fusion_name_upper.end(), fusion_name_upper.begin(), ::toupper);
  123. auto iter =
  124. std::find_if(fusion_type_name_maps.begin(), fusion_type_name_maps.end(), [&fusion_name_upper](const auto &it) {
  125. std::string name_upper = it.second;
  126. transform(name_upper.begin(), name_upper.end(), name_upper.begin(), ::toupper);
  127. return fusion_name_upper == name_upper;
  128. });
  129. if (iter == fusion_type_name_maps.end()) {
  130. MS_LOG(EXCEPTION) << "Illegal fusion name: " << name;
  131. }
  132. return iter->first;
  133. }
// Creates the GPU kernel-meta directory and marks this cache as initialized.
// mkdir failure is only logged at INFO level: the directory may already
// exist, or be created later by another component.
void KernelMeta::Initialize() {
  kernel_meta_path_ = std::string(kGpuKernelMeta) + "/";
#if defined(_WIN32) || defined(_WIN64)
  // Windows mkdir takes no permission mode.
  auto ret = mkdir(kernel_meta_path_.c_str());
#else
  // POSIX: owner and group get read/write/execute.
  auto ret = mkdir(kernel_meta_path_.c_str(), S_IRWXG | S_IRWXU);
#endif
  if (ret != 0) {
    MS_LOG(INFO) << "kernel dir [" << kernel_meta_path_ << "], will be created later";
  }
  // Initialization is reported as successful even if mkdir failed (see above).
  initialized_ = true;
}
  146. std::string KernelMeta::Search(const std::string &kernel_name) const {
  147. if (!initialized_) {
  148. return "";
  149. }
  150. auto iter = kernel_meta_map_.find(kernel_name);
  151. if (iter == kernel_meta_map_.end()) {
  152. return "";
  153. } else {
  154. return iter->second;
  155. }
  156. }
  157. bool KernelMeta::Insert(const std::string &kernel_name, const std::string &kernel_json) {
  158. if (!initialized_) {
  159. return false;
  160. }
  161. kernel_meta_map_[kernel_name] = kernel_json;
  162. return true;
  163. }
  164. bool CheckCache(const std::string &kernel_name) {
  165. // check cache.
  166. KernelMeta *bin_map = KernelMeta::GetInstance();
  167. if (bin_map == nullptr) {
  168. MS_LOG(DEBUG) << "Kernel cache is invalid, kernel_name: " << kernel_name;
  169. return false;
  170. }
  171. std::string kernel_json = bin_map->Search(kernel_name);
  172. bool ret = (!kernel_json.empty());
  173. if (ret) {
  174. MS_LOG(INFO) << "Kernel name:" << kernel_name << " has registered.";
  175. } else {
  176. MS_LOG(INFO) << "Kernel name:" << kernel_name << " will been registered.";
  177. }
  178. return ret;
  179. }
  180. KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor) {
  181. // search cache.
  182. KernelMeta *bin_map = KernelMeta::GetInstance();
  183. if (bin_map == nullptr) {
  184. MS_LOG(DEBUG) << "kernel cache is invalid, kernel_name: " << kernel_name;
  185. return nullptr;
  186. }
  187. std::string kernel_json = bin_map->Search(kernel_name);
  188. if (!kernel_json.empty()) {
  189. KernelPackPtr kernel_pack = std::make_shared<KernelPack>();
  190. // just a tmp solution.
  191. if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) {
  192. MS_LOG(ERROR) << "Read cache json and bin file failed[" << kernel_json << "].";
  193. return nullptr;
  194. } else {
  195. return kernel_pack;
  196. }
  197. } else {
  198. MS_LOG(INFO) << "The cache kernel not found[" << kernel_name << "].";
  199. return nullptr;
  200. }
  201. }
  202. KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor) {
  203. MS_LOG(INFO) << "Insert cache for kernel:" << kernel_name << ", processr:" << processor;
  204. KernelMeta *bin_map = KernelMeta::GetInstance();
  205. std::string kernel_json;
  206. if (processor == kProcessorAiCore || processor == kProcessorAiCpu) {
  207. kernel_json = kCceKernelMeta;
  208. } else {
  209. kernel_json = bin_map->kernel_meta_path();
  210. }
  211. (void)kernel_json.append(kernel_name).append(kJsonSuffix);
  212. KernelPackPtr kernel_pack = std::make_shared<KernelPack>();
  213. if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) {
  214. MS_LOG(ERROR) << "Read json and bin file failed[" << kernel_json << "].";
  215. return nullptr;
  216. }
  217. if (bin_map == nullptr) {
  218. MS_LOG(DEBUG) << "Kernel cache is invalid, kernel name :" << kernel_name;
  219. return nullptr;
  220. }
  221. if (bin_map->Insert(kernel_name, kernel_json)) {
  222. MS_LOG(INFO) << "Kernel insert cache success[" << kernel_json << "], kernel name[" << kernel_name << "].";
  223. }
  224. return kernel_pack;
  225. }
  226. TypeId DtypeToTypeId(const std::string &dtypes) {
  227. auto iter = type_id_maps.find(dtypes);
  228. if (iter != type_id_maps.end()) {
  229. return iter->second;
  230. } else {
  231. MS_EXCEPTION(ArgumentError) << "Illegal input device dtype:" << dtypes;
  232. }
  233. }
  234. std::string TypeId2String(TypeId type_id, bool unknown_as_default) {
  235. auto iter = type_id_str_map.find(type_id);
  236. if (iter == type_id_str_map.end()) {
  237. if (!unknown_as_default) {
  238. MS_EXCEPTION(ArgumentError) << "Illegal input dtype." << TypeIdLabel(type_id);
  239. }
  240. MS_LOG(INFO) << "Using default dtype: float32";
  241. return "float32";
  242. }
  243. return iter->second;
  244. }
  245. std::string Dtype2ShortType(const std::string &dtype) {
  246. auto iter = dtype_shortdtype_map_.find(dtype);
  247. if (iter != dtype_shortdtype_map_.end()) {
  248. return iter->second;
  249. } else {
  250. MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtype;
  251. }
  252. }
  253. size_t GetDtypeNbyte(const std::string &dtype) {
  254. auto iter = dtype_nbyte_map.find(dtype);
  255. if (iter != dtype_nbyte_map.end()) {
  256. return iter->second;
  257. } else {
  258. MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtype;
  259. }
  260. }
  261. bool SetInputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inputs, size_t real_input_num,
  262. size_t builder_idex, const std::vector<int64_t> &dyn_input_sizes,
  263. const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
  264. MS_EXCEPTION_IF_NULL(builder);
  265. std::vector<TypeId> inputs_device_type;
  266. std::vector<std::string> inputs_format;
  267. size_t dyn_input_idx = 0;
  268. size_t kernel_info_index = 0;
  269. MS_EXCEPTION_IF_NULL(inputs[0]);
  270. size_t kernel_info_cnt = inputs[0]->dtypes().size();
  271. for (const auto &input : inputs) {
  272. MS_EXCEPTION_IF_NULL(input);
  273. std::string param_type = input->param_type();
  274. std::vector<std::string> dtypes = input->dtypes();
  275. std::vector<std::string> formats = input->formats();
  276. if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) {
  277. MS_LOG(DEBUG) << "Set input kernel builder info failed, dtyps size != formats size. dtypes size: "
  278. << dtypes.size() << ", formats size : " << formats.size();
  279. return false;
  280. }
  281. if (param_type == "dynamic") {
  282. if (dyn_input_sizes.empty()) {
  283. MS_LOG(DEBUG) << "Set input kernel builder info failed, dyn_input_sizes's size is 0 when param_type is dynamic";
  284. return false;
  285. }
  286. for (int64_t t = 0; t < dyn_input_sizes[dyn_input_idx]; t++) {
  287. kernel_info_index++;
  288. auto type_id = DtypeToTypeId(dtypes[builder_idex]);
  289. inputs_device_type.push_back(type_id);
  290. inputs_format.push_back(formats[builder_idex]);
  291. }
  292. dyn_input_idx++;
  293. } else if (param_type == "required") {
  294. kernel_info_index++;
  295. auto type_id = DtypeToTypeId(dtypes[builder_idex]);
  296. inputs_device_type.push_back(type_id);
  297. inputs_format.push_back(formats[builder_idex]);
  298. } else {
  299. if (kernel_info_index < real_input_num) {
  300. MS_LOG(INFO) << "Set input kernel builder info, input type is optional, input index is :" << kernel_info_index;
  301. kernel_info_index++;
  302. auto type_id = DtypeToTypeId(dtypes[builder_idex]);
  303. inputs_device_type.push_back(type_id);
  304. inputs_format.push_back(formats[builder_idex]);
  305. }
  306. }
  307. }
  308. builder->SetInputsDeviceType(inputs_device_type);
  309. builder->SetInputsFormat(inputs_format);
  310. return true;
  311. }
  312. bool SetOutputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> &outputs, size_t builder_idex,
  313. const size_t &real_output_num,
  314. const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
  315. // not now but in the next we need to support dynamic output case
  316. MS_EXCEPTION_IF_NULL(builder);
  317. size_t output_idx = 0;
  318. std::vector<TypeId> outputs_device_type;
  319. std::vector<std::string> outputs_format;
  320. MS_EXCEPTION_IF_NULL(outputs[0]);
  321. size_t kernel_info_cnt = outputs[0]->dtypes().size();
  322. for (const auto &output : outputs) {
  323. MS_EXCEPTION_IF_NULL(output);
  324. if (output_idx >= real_output_num) {
  325. MS_LOG(DEBUG) << "real_output_num:" << real_output_num << ", output_idx:" << output_idx << " is out of limit!";
  326. continue;
  327. }
  328. size_t output_num = 0;
  329. if (output->param_type() == "dynamic") {
  330. if (outputs.size() > 1) {
  331. MS_EXCEPTION(ArgumentError) << "Dynamic output is unsupported multi output!";
  332. }
  333. output_num = real_output_num;
  334. } else if (output->param_type() == "required") {
  335. output_num = 1;
  336. } else {
  337. if (output_idx < real_output_num) {
  338. MS_LOG(DEBUG) << "Set output kernel builder info, output type is optional, output index is :" << output_idx;
  339. output_num = 1;
  340. }
  341. }
  342. for (size_t i = 0; i < output_num; i++) {
  343. std::vector<std::string> dtypes = output->dtypes();
  344. std::vector<std::string> formats = output->formats();
  345. if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) {
  346. MS_LOG(DEBUG) << "Set output kernel builder info, dtyps size != formats size.";
  347. return false;
  348. }
  349. auto type_id = DtypeToTypeId(dtypes[builder_idex]);
  350. outputs_device_type.push_back(type_id);
  351. outputs_format.push_back(formats[builder_idex]);
  352. output_idx++;
  353. }
  354. }
  355. builder->SetOutputsFormat(outputs_format);
  356. builder->SetOutputsDeviceType(outputs_device_type);
  357. return true;
  358. }
  359. void SetKernelBuildInfo(const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder, Processor processor,
  360. const std::shared_ptr<const OpInfo> &op_info_ptr) {
  361. MS_EXCEPTION_IF_NULL(builder);
  362. MS_EXCEPTION_IF_NULL(op_info_ptr);
  363. auto imply_type = op_info_ptr->imply_type();
  364. builder->SetProcessor(processor);
  365. std::string fusion_name = op_info_ptr->fusion_type();
  366. auto fusion_type = GetFusionTypeByName(fusion_name);
  367. builder->SetFusionType(fusion_type);
  368. if (imply_type == kAKG) {
  369. builder->SetKernelType(AKG_KERNEL);
  370. } else if (imply_type == kAICPU) {
  371. builder->SetKernelType(AICPU_KERNEL);
  372. } else {
  373. builder->SetKernelType(TBE_KERNEL);
  374. }
  375. }
// Builds one KernelBuildInfo per kernel-info variant described by the op info
// and appends them to kernel_info_list. Variant count is taken from the
// first input's (or, lacking inputs, first output's) dtypes list. With
// neither inputs nor outputs, a single build info is emitted for AICPU only.
// Returns false when input/output builder info cannot be assembled.
bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr, Processor processor,
                   std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(kernel_info_list);
  size_t real_input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  size_t real_output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  std::vector<std::shared_ptr<OpIOInfo>> inputs = op_info_ptr->inputs_ptr();
  std::vector<std::shared_ptr<OpIOInfo>> outputs = op_info_ptr->outputs_ptr();
  std::vector<int64_t> dyn_input_sizes;
  auto primitive = AnfAlgo::GetCNodePrimitive(kernel_node);
  MS_EXCEPTION_IF_NULL(primitive);
  auto op_name = AnfAlgo::GetCNodeName(kernel_node);
  // Dynamic-input element counts, when the primitive carries them.
  if (primitive->GetAttr("dyn_input_sizes") != nullptr) {
    dyn_input_sizes = GetValue<std::vector<int64_t>>(primitive->GetAttr("dyn_input_sizes"));
  }
  if (inputs.size() > 0) {
    if (inputs[0] == nullptr) {
      MS_LOG(EXCEPTION) << "Inputs[0] is nullptr. Op name: " << op_name;
    }
    // One build info per dtype/format variant listed for the first input.
    size_t kernel_info_cnt = inputs[0]->dtypes().size();
    for (size_t j = 0; j < kernel_info_cnt; j++) {
      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(builder);
      SetKernelBuildInfo(builder, processor, op_info_ptr);
      if (!SetInputKernelBuilderInfo(inputs, real_input_num, j, dyn_input_sizes, builder)) {
        MS_LOG(DEBUG) << "Parse kernel metadata, set inputs kernel builder info failed. Op name: " << op_name;
        return false;
      }
      if (outputs.size() > 0) {
        if (!SetOutputKernelBuilderInfo(outputs, j, real_output_num, builder)) {
          MS_LOG(DEBUG) << "Parse kernel metadata, set outputs kernel builder info failed. Op name: " << op_name;
          return false;
        }
      }
      kernel_info_list->push_back(builder->Build());
    }
  } else if (outputs.size() > 0) {
    // No inputs: variant count comes from the first output instead.
    if (outputs[0] == nullptr) {
      MS_LOG(EXCEPTION) << "Outputs[0] is nullptr. Op name: " << op_name;
    }
    size_t kernel_info_cnt = outputs[0]->dtypes().size();
    for (size_t j = 0; j < kernel_info_cnt; j++) {
      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(builder);
      SetKernelBuildInfo(builder, processor, op_info_ptr);
      if (!SetOutputKernelBuilderInfo(outputs, j, real_output_num, builder)) {
        MS_LOG(DEBUG) << "Parse kernel metadata, set outputs kernel builder info failed. Op name: " << op_name;
        return false;
      }
      kernel_info_list->push_back(builder->Build());
    }
  } else {
    // No inputs and no outputs: only AICPU ops get a (bare) build info.
    if (processor == AICPU) {
      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(builder);
      SetKernelBuildInfo(builder, processor, op_info_ptr);
      kernel_info_list->push_back(builder->Build());
    }
  }
  return true;
}
  437. void SaveJsonInfo(const std::string &json_name, const std::string &info, const std::string &base_path) {
  438. std::string path = base_path + json_name + kInfoSuffix;
  439. auto realpath = Common::GetRealPath(path);
  440. if (!realpath.has_value()) {
  441. MS_LOG(ERROR) << "Get real path failed, path=" << path;
  442. return;
  443. }
  444. ChangeFileMode(realpath.value(), S_IWUSR);
  445. std::ofstream filewrite(realpath.value());
  446. if (!filewrite.is_open()) {
  447. MS_LOG(ERROR) << "Open file '" << realpath.value() << "' failed!";
  448. return;
  449. }
  450. filewrite << info << std::endl;
  451. filewrite.close();
  452. ChangeFileMode(realpath.value(), S_IRUSR);
  453. }
  454. Processor GetProcessor(const string &processor) {
  455. if (processor == kProcessorAiCore) return Processor::AICORE;
  456. if (processor == kProcessorAiCpu) return Processor::AICPU;
  457. if (processor == kProcessorCuda) return Processor::CUDA;
  458. MS_LOG(DEBUG) << "Unknown processor type.";
  459. return Processor::UNKNOWN;
  460. }
  461. std::string GetProcessor(const AnfNodePtr &anf_node) {
  462. MS_EXCEPTION_IF_NULL(anf_node);
  463. std::string device;
  464. switch (AnfAlgo::GetProcessor(anf_node)) {
  465. case Processor::AICORE:
  466. device = kProcessorAiCore;
  467. break;
  468. case Processor::AICPU:
  469. device = kProcessorAiCpu;
  470. break;
  471. case Processor::CUDA:
  472. device = kProcessorCuda;
  473. break;
  474. default:
  475. MS_LOG(DEBUG) << "Unknown processor type.";
  476. break;
  477. }
  478. return device;
  479. }
  480. bool IsSameShape(const std::vector<size_t> &shape_a, const std::vector<size_t> &shape_b) {
  481. if (shape_a.size() != shape_b.size()) {
  482. return false;
  483. }
  484. for (size_t i = 0; i < shape_a.size(); ++i) {
  485. if (shape_a[i] != shape_b[i]) {
  486. return false;
  487. }
  488. }
  489. return true;
  490. }
  491. int Sign(float x) {
  492. if (x > 0) {
  493. return 1;
  494. }
  495. if (x < 0) {
  496. return -1;
  497. }
  498. return 0;
  499. }
  500. std::pair<AnfNodePtr, size_t> GetKernelInput(const AnfNodePtr &anf_node, size_t index) {
  501. MS_EXCEPTION_IF_NULL(anf_node);
  502. if (index >= AnfAlgo::GetInputTensorNum(anf_node)) {
  503. MS_EXCEPTION(ArgumentError) << "Index is out of the size of anf_node inputs. Node info : ["
  504. << anf_node->DebugString() << "]";
  505. }
  506. auto cnode = anf_node->cast<CNodePtr>();
  507. if (cnode == nullptr) {
  508. return AnfAlgo::VisitKernel(anf_node, 0);
  509. } else {
  510. return AnfAlgo::VisitKernel(anf_node->cast<CNodePtr>()->input(index + 1), 0);
  511. }
  512. }
// For each graph input, finds a node in node_list that uses it and records
// (user node, (input position, offset)). For ordinary inputs the offset is 0;
// for dynamic inputs the flat position is decomposed into
// (dynamic group index, offset within that group) using kAttrDynInputSizes.
// Throws ArgumentError when an input has no users or no user in node_list.
std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> GetInputIndex(const std::vector<AnfNodePtr> &node_list,
                                                                            const std::vector<AnfNodePtr> &input_list) {
  std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> input_index;
  for (size_t i = 0; i < input_list.size(); ++i) {
    auto const &input = input_list[i];
    MS_EXCEPTION_IF_NULL(input);
    bool found = false;
    // All users of this input are looked up through the graph manager.
    auto mng = input->func_graph()->manager();
    MS_EXCEPTION_IF_NULL(mng);
    const NodeUsersMap &users = mng->node_users();
    auto input_users = users.find(input);
    if (input_users == users.end() || input_users->second.empty()) {
      MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of ["
                                  << input->func_graph()->ToString() << "] has no users.";
    }
    for (auto const &input_user : input_users->second) {
      for (auto const &anf_node : node_list) {
        // Only users that are themselves in node_list count.
        if (anf_node != input_user.first) {
          continue;
        }
        std::vector<int64_t> dyn_input_sizes;
        auto prim = AnfAlgo::GetCNodePrimitive(anf_node);
        MS_EXCEPTION_IF_NULL(prim);
        if (prim->GetAttr(kAttrDynInputSizes) != nullptr) {
          dyn_input_sizes = GetValue<const std::vector<int64_t>>(prim->GetAttr(kAttrDynInputSizes));
        }
        if (dyn_input_sizes.empty()) {
          // No dynamic inputs: input_user.second is the 1-based CNode input
          // slot, hence the -1 to get the tensor position.
          input_index.push_back(std::make_pair(anf_node, std::make_pair(IntToSize(input_user.second - 1), 0)));
          found = true;
          break;
        } else {
          // Walk the dynamic group sizes to locate which group the flat
          // position falls into and its offset inside that group.
          int used_as_idx = input_user.second - 1;
          int accum_idx = 0;
          size_t dyn_i = 0;
          for (; dyn_i < dyn_input_sizes.size(); ++dyn_i) {
            accum_idx += LongToInt(dyn_input_sizes[dyn_i]);
            if (used_as_idx < accum_idx) {
              input_index.push_back(std::make_pair(
                anf_node, std::make_pair(dyn_i, IntToSize(used_as_idx - (accum_idx - dyn_input_sizes[dyn_i])))));
              break;
            }
          }
          // dyn_i < size means the inner loop broke, i.e. a group was found.
          if (dyn_i != dyn_input_sizes.size()) {
            found = true;
            break;
          }
        }
      }
      if (found) {
        break;
      }
    }
    if (!found) {
      MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of ["
                                  << input->func_graph()->ToString() << "] found no related kernel info.";
    }
  }
  return input_index;
}
// For each graph output, resolves the producing kernel via VisitKernel and
// records (producer, output index). Producers found in node_list keep their
// visited output index; outputs fed directly from a graph input are recorded
// with index 0. Throws ArgumentError when a producer is in neither list.
std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list,
                                                          const std::vector<AnfNodePtr> &input_list,
                                                          const std::vector<AnfNodePtr> &output_list) {
  std::vector<std::pair<AnfNodePtr, size_t>> output_index;
  for (size_t i = 0; i < output_list.size(); ++i) {
    auto const &output = output_list[i];
    MS_EXCEPTION_IF_NULL(output);
    bool found = false;
    // Skip over transparent nodes (e.g. tuple getitems) to the real kernel.
    auto pree_node = AnfAlgo::VisitKernel(output, 0);
    auto pos = std::find(std::begin(node_list), std::end(node_list), pree_node.first);
    if (pos != std::end(node_list)) {
      output_index.push_back(pree_node);
      continue;
    }
    // Not a kernel in this graph: it may be a pass-through of a graph input.
    auto ret = std::find(std::begin(input_list), std::end(input_list), pree_node.first);
    if (ret != std::end(input_list)) {
      output_index.push_back(std::make_pair(pree_node.first, 0));
      found = true;
    }
    if (!found) {
      MS_EXCEPTION(ArgumentError) << "Output [" << i << "][" << output->DebugString(2) << "] of ["
                                  << output->func_graph()->ToString() << "] found no related kernel info.";
    }
  }
  return output_index;
}
  598. void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list) {
  599. MS_EXCEPTION_IF_NULL(node_list);
  600. MS_EXCEPTION_IF_NULL(func_graph);
  601. std::vector<AnfNodePtr> node_lists = TopoSort(func_graph->get_return());
  602. for (auto const &node : node_lists) {
  603. if (!AnfAlgo::IsRealKernel(node) || !node->isa<CNode>()) {
  604. continue;
  605. }
  606. auto cnode = node->cast<CNodePtr>();
  607. MS_EXCEPTION_IF_NULL(cnode);
  608. if (IsValueNode<Primitive>(cnode->input(kAnfPrimitiveIndex))) {
  609. node_list->push_back(node);
  610. }
  611. }
  612. }
  613. void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list,
  614. std::vector<AnfNodePtr> *input_list, std::vector<AnfNodePtr> *output_list) {
  615. MS_EXCEPTION_IF_NULL(func_graph);
  616. MS_EXCEPTION_IF_NULL(node_list);
  617. MS_EXCEPTION_IF_NULL(input_list);
  618. GetValidKernelNodes(func_graph, node_list);
  619. auto parameters = func_graph->parameters();
  620. input_list->insert(input_list->begin(), parameters.begin(), parameters.end());
  621. GetFuncGraphOutputNodes(func_graph, output_list);
  622. }
  623. void GetFuncGraphOutputNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *output_list) {
  624. MS_EXCEPTION_IF_NULL(func_graph);
  625. MS_EXCEPTION_IF_NULL(output_list);
  626. auto func_output = func_graph->output();
  627. MS_EXCEPTION_IF_NULL(func_output);
  628. if (func_output->isa<CNode>()) {
  629. // multi output.
  630. auto cnode = func_output->cast<CNodePtr>();
  631. MS_EXCEPTION_IF_NULL(cnode);
  632. auto input0 = cnode->input(kAnfPrimitiveIndex);
  633. MS_EXCEPTION_IF_NULL(input0);
  634. if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
  635. for (size_t input_idx = 1; input_idx < cnode->inputs().size(); ++input_idx) {
  636. auto input_node = cnode->input(input_idx);
  637. MS_EXCEPTION_IF_NULL(input_node);
  638. if (input_node->isa<CNode>() && AnfAlgo::GetInputTensorNum(input_node) == 0) {
  639. continue;
  640. }
  641. output_list->push_back(AnfAlgo::VisitKernel(input_node, 0).first);
  642. }
  643. } else {
  644. // single output.
  645. output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first);
  646. }
  647. } else {
  648. // single output.
  649. output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first);
  650. }
  651. }
// Extract the scalar value held by a const-tensor input of `anf_node` and
// store it under (*node_json)["value"]. Returns true on success; returns
// false when the input is not a value-node tensor, has a null value, or
// holds more than one element. Throws ArgumentError for an out-of-range
// input index and raises an exception for an unsupported tensor dtype.
bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *const node_json) {
  MS_EXCEPTION_IF_NULL(anf_node);
  MS_EXCEPTION_IF_NULL(node_json);
  auto cnode = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // Slot 0 of a CNode is the primitive, so real input input_idx is at input_idx + 1.
  if (input_idx + 1 >= cnode->size()) {
    MS_EXCEPTION(ArgumentError) << "input_idx [" << input_idx << "] is out of index of inputs of ["
                                << cnode->inputs().size() << "][" << cnode->DebugString() << "]";
  }
  auto input_node = cnode->input(input_idx + 1);
  if (!IsValueNode<tensor::Tensor>(input_node)) {
    return false;
  }
  auto tensor = GetValueNode<tensor::TensorPtr>(input_node);
  if (tensor == nullptr) {
    MS_LOG(DEBUG) << "Value of input node is nullptr, op: [" << input_node->DebugString() << "]";
    return false;
  }
  auto type_id = tensor->data_type();
  auto *data = tensor->data_c();
  MS_EXCEPTION_IF_NULL(data);
  if (tensor->DataSize() > 1) {
    // Only single-element (scalar) tensors are folded into the json value.
    MS_LOG(WARNING) << "Not take value of tensor whose datasize greater than 1, [" << input_node->DebugString(2) << "]";
    return false;
  }
  // Dispatch on the tensor dtype and reinterpret the raw buffer accordingly.
  if (type_id == kFloat64->type_id()) {
    (*node_json)["value"] = static_cast<double *>(data)[0];
  } else if (type_id == kFloat32->type_id()) {
    (*node_json)["value"] = static_cast<float *>(data)[0];
  } else if (type_id == kFloat16->type_id()) {
    // float16 has no native json mapping; widen to float first.
    float16 *val = static_cast<float16 *>(data);
    (*node_json)["value"] = static_cast<float>(val[0]);
  } else if (type_id == kUInt64->type_id()) {
    (*node_json)["value"] = static_cast<uint64_t *>(data)[0];
  } else if (type_id == kUInt32->type_id()) {
    (*node_json)["value"] = static_cast<uint32_t *>(data)[0];
  } else if (type_id == kUInt16->type_id()) {
    (*node_json)["value"] = static_cast<uint16_t *>(data)[0];
  } else if (type_id == kUInt8->type_id()) {
    (*node_json)["value"] = static_cast<uint8_t *>(data)[0];
  } else if (type_id == kInt64->type_id()) {
    (*node_json)["value"] = static_cast<int64_t *>(data)[0];
  } else if (type_id == kInt32->type_id()) {
    (*node_json)["value"] = static_cast<int32_t *>(data)[0];
  } else if (type_id == kInt16->type_id()) {
    (*node_json)["value"] = static_cast<int16_t *>(data)[0];
  } else if (type_id == kInt8->type_id()) {
    (*node_json)["value"] = static_cast<int8_t *>(data)[0];
  } else if (type_id == kBool->type_id()) {
    (*node_json)["value"] = static_cast<bool *>(data)[0];
  } else {
    MS_LOG(EXCEPTION) << "Unknown value type of tensor[" << cnode->DebugString() << "]";
  }
  return true;
}
  708. bool IsWeightBoundary(const AnfNodePtr &node) {
  709. if (node->isa<ValueNode>()) {
  710. return true;
  711. }
  712. if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) {
  713. return true;
  714. }
  715. return false;
  716. }
  717. std::vector<int64_t> GetReduceAttrAxis(const CNodePtr &cnode) {
  718. if (AnfAlgo::GetInputTensorNum(cnode) != AnfAlgo::GetOutputTensorNum(cnode) &&
  719. AnfAlgo::GetInputTensorNum(cnode) != 1) {
  720. MS_LOG(EXCEPTION) << "The reduce node [" << cnode->DebugString() << "] is not single input or single output ";
  721. }
  722. std::vector<int64_t> axis;
  723. auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
  724. auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
  725. MS_EXCEPTION_IF_NULL(primitive);
  726. auto axis_attr = primitive->GetAttr(kAxis);
  727. if (axis_attr == nullptr) {
  728. MS_LOG(ERROR) << "This node doesn't have axis attr. Node info [" << cnode->DebugString() << "]";
  729. return std::vector<int64_t>();
  730. }
  731. std::vector<int64_t> axis_list;
  732. if (axis_attr->isa<Int64Imm>()) {
  733. (void)axis_list.emplace_back(GetValue<int64_t>(axis_attr));
  734. } else {
  735. axis_list = GetValue<std::vector<int64_t>>(axis_attr);
  736. }
  737. for (const auto &elem : axis_list) {
  738. if (elem < 0) {
  739. (void)axis.emplace_back(input_shape.size() + elem);
  740. } else {
  741. (void)axis.emplace_back(elem);
  742. }
  743. }
  744. AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis), cnode);
  745. return axis;
  746. }
  747. std::string GetProcessorStr(const AnfNodePtr &anf_node) {
  748. MS_EXCEPTION_IF_NULL(anf_node);
  749. std::string processor = kProcessorUnknown;
  750. auto kernel_info = dynamic_cast<device::KernelInfo *>(anf_node->kernel_info());
  751. MS_EXCEPTION_IF_NULL(kernel_info);
  752. auto build_info = kernel_info->select_kernel_build_info();
  753. // we may call this before kernel select.
  754. if (build_info == nullptr) {
  755. return processor;
  756. }
  757. switch (build_info->processor()) {
  758. case Processor::AICORE:
  759. processor = kProcessorAiCore;
  760. break;
  761. case Processor::AICPU:
  762. processor = kProcessorAiCpu;
  763. break;
  764. case Processor::CUDA:
  765. processor = kProcessorCuda;
  766. break;
  767. default:
  768. MS_LOG(ERROR) << "Unknown processor type.";
  769. break;
  770. }
  771. return processor;
  772. }
  773. Processor GetProcessorFromContext() {
  774. kernel::Processor processor = kernel::Processor::UNKNOWN;
  775. auto context_ptr = MsContext::GetInstance();
  776. MS_EXCEPTION_IF_NULL(context_ptr);
  777. auto device_info = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  778. if (device_info == kGPUDevice) {
  779. processor = kernel::Processor::CUDA;
  780. } else if (device_info == kAscendDevice) {
  781. processor = kernel::Processor::AICORE;
  782. }
  783. return processor;
  784. }
  785. std::string GetStrProcessorFromContext() {
  786. auto processor = GetProcessorFromContext();
  787. string str_processor = kernel::kProcessorUnknown;
  788. if (processor == kernel::Processor::CUDA) {
  789. str_processor = kernel::kProcessorCuda;
  790. } else if (processor == kernel::Processor::AICORE) {
  791. str_processor = kernel::kProcessorAiCore;
  792. }
  793. return str_processor;
  794. }
  795. float Scaling(size_t in_size, size_t out_size, bool align_corners) {
  796. return (align_corners && out_size > 1) ? (in_size - 1) / static_cast<float>(out_size - 1)
  797. : in_size / static_cast<float>(out_size);
  798. }
  799. float ScaleGrid(const int x, const float scale) { return static_cast<float>(x) * scale; }
  800. void ComputeInterpolationWeights(const size_t out_size, const size_t in_size, const float scale,
  801. CachedInterpolation *interpolation) {
  802. interpolation[out_size].lower = 0;
  803. interpolation[out_size].upper = 0;
  804. for (size_t i = 0; i <= out_size - 1; ++i) {
  805. const float in = ScaleGrid(i, scale);
  806. const float in_f = std::floor(in);
  807. interpolation[i].lower = std::max(static_cast<size_t>(in_f), static_cast<size_t>(0));
  808. interpolation[i].upper = std::min(static_cast<size_t>(std::ceil(in)), in_size - 1);
  809. interpolation[i].lerp = in - in_f;
  810. }
  811. }
  812. bool GetShapeSize(const std::vector<size_t> &shape, const TypePtr &type_ptr, int64_t *size_i) {
  813. MS_EXCEPTION_IF_NULL(type_ptr);
  814. size_t type_byte = GetTypeByte(type_ptr);
  815. if (type_byte == 0) {
  816. return false;
  817. }
  818. for (size_t j = 0; j < shape.size(); j++) {
  819. size_i[0] = LongMulWithOverflowCheck(size_i[0], static_cast<int>(shape[j]));
  820. }
  821. size_i[0] = LongMulWithOverflowCheck(size_i[0], SizeToInt(type_byte));
  822. return true;
  823. }
  824. void CastShapeSizeToLong(const std::vector<size_t> &shape, std::vector<int64_t> *long_shape) {
  825. MS_EXCEPTION_IF_NULL(long_shape);
  826. std::transform(shape.begin(), shape.end(), std::back_inserter(*long_shape), SizeToLong);
  827. }
  828. void CheckSliceValid(const std::vector<int64_t> &start, const std::vector<int64_t> &stop,
  829. const std::vector<int64_t> &step, const std::vector<int64_t> &input_shape) {
  830. if (start.size() != stop.size() || start.size() != step.size() || start.size() > input_shape.size()) {
  831. MS_LOG(EXCEPTION)
  832. << "TensorCopySlices requires the length of begin, stride and end must be equal and less than input dimension.";
  833. }
  834. size_t size = start.size();
  835. for (size_t i = 0; i < size; ++i) {
  836. if (stop[i] <= start[i]) {
  837. MS_LOG(EXCEPTION) << "Invalid slice: (" << start[i] << ", " << stop[i] << " ," << step[i] << ")";
  838. }
  839. // Operator need to be generalized in the future. Only support to copy continuous memory now.
  840. if (step[i] != 1) {
  841. MS_LOG(EXCEPTION) << "The element in step only support 1, but got:" << step;
  842. }
  843. }
  844. size_t slice_pos = size;
  845. for (size_t i = 0; i < size; ++i) {
  846. if (stop[i] - start[i] > 1) {
  847. slice_pos = i;
  848. break;
  849. }
  850. }
  851. for (size_t i = slice_pos + 1; i < size; ++i) {
  852. if (stop[i] - start[i] != input_shape[i]) {
  853. MS_LOG(EXCEPTION) << "Only support copy continuous memory now. For example tensor[0, 0:100] is fine, "
  854. "but tensor[0:100, 0] is not supported.";
  855. }
  856. }
  857. }
// Number of elements covered by one contiguous copy of the slice [start, stop).
// Scans outermost-first: the first dimension with extent > 1 determines the
// size as extent * dim_offset[i] (elements spanned per index step in dim i).
// NOTE(review): assumes start/stop/dim_offset are non-empty and equal-length —
// an empty slice would read dim_offset[start.size() - 1] out of bounds;
// presumably callers validate via CheckSliceValid first — confirm.
size_t GetCopySize(const std::vector<int64_t> &dim_offset, const std::vector<int64_t> &start,
                   const std::vector<int64_t> &stop) {
  for (size_t i = 0; i < start.size(); ++i) {
    if (stop[i] - start[i] != 1) {
      return SizetMulWithOverflowCheck(LongToSize(stop[i] - start[i]), LongToSize(dim_offset[i]));
    }
  }
  // All extents are 1: copy a single innermost stride's worth of elements.
  return LongToSize(dim_offset[start.size() - 1]);
}
  867. std::vector<int64_t> CalDimOffset(const std::vector<int64_t> &input_shape) {
  868. std::vector<int64_t> dim_offset;
  869. int64_t offset = 1;
  870. for (auto iter = input_shape.rbegin(); iter != input_shape.rend(); ++iter) {
  871. dim_offset.push_back(offset);
  872. offset = offset * (*iter);
  873. }
  874. std::reverse(dim_offset.begin(), dim_offset.end());
  875. return dim_offset;
  876. }
  877. size_t CalOffset(const std::vector<int64_t> &start, const std::vector<int64_t> &stop,
  878. const std::vector<int64_t> &dim_offset) {
  879. size_t size = start.size();
  880. size_t offset = 0;
  881. for (size_t i = 0; i < size; ++i) {
  882. offset += SizetMulWithOverflowCheck(LongToSize(dim_offset[i]), start[i]);
  883. if (stop[i] - start[i] != 1) {
  884. break;
  885. }
  886. }
  887. return offset;
  888. }
  889. size_t UnitSizeInBytes(const mindspore::TypeId &t) {
  890. size_t bytes = 0;
  891. switch (t) {
  892. case kNumberTypeBool:
  893. case kNumberTypeInt8:
  894. case kNumberTypeUInt8:
  895. bytes = sizeof(int8_t);
  896. break;
  897. case kNumberTypeInt16:
  898. case kNumberTypeUInt16:
  899. case kNumberTypeFloat16:
  900. bytes = sizeof(int16_t);
  901. break;
  902. case kNumberTypeInt:
  903. case kNumberTypeUInt:
  904. case kNumberTypeInt32:
  905. case kNumberTypeUInt32:
  906. case kNumberTypeFloat:
  907. case kNumberTypeFloat32:
  908. bytes = sizeof(int32_t);
  909. break;
  910. case kNumberTypeUInt64:
  911. case kNumberTypeInt64:
  912. case kNumberTypeFloat64:
  913. bytes = sizeof(int64_t);
  914. break;
  915. default:
  916. MS_LOG(EXCEPTION) << "Invalid types " << t;
  917. break;
  918. }
  919. return bytes;
  920. }
  921. } // namespace kernel
  922. } // namespace mindspore