You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

common_utils.cc 39 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "kernel/common_utils.h"
  17. #include <unordered_map>
  18. #include <map>
  19. #include <iostream>
  20. #include <utility>
  21. #include <fstream>
  22. #include <thread>
  23. #include "nlohmann/json.hpp"
  24. #include "session/anf_runtime_algorithm.h"
  25. #include "common/utils.h"
  26. #include "ir/manager.h"
  27. #include "ir/meta_tensor.h"
  28. #include "ir/func_graph.h"
  29. #include "operator/ops.h"
  30. #include "utils/graph_utils.h"
  31. namespace mindspore {
  32. namespace kernel {
// Attribute-name / dtype-name string literals.
// NOTE(review): neither constant is referenced in this portion of the file;
// presumably used further down — confirm before removing.
constexpr char kAxis[] = "axis";
constexpr char kTypeInt32[] = "Int32";
// Lookup table from dtype strings (as they appear in op registrations) to
// TypeId. Note that the bare "float" string is treated as float32 here.
const std::unordered_map<std::string, TypeId> type_id_maps = {
  {"float", TypeId::kNumberTypeFloat32}, {"float16", TypeId::kNumberTypeFloat16},
  {"float32", TypeId::kNumberTypeFloat32}, {"float64", TypeId::kNumberTypeFloat64},
  {"int", TypeId::kNumberTypeInt},         {"int8", TypeId::kNumberTypeInt8},
  {"int16", TypeId::kNumberTypeInt16},     {"int32", TypeId::kNumberTypeInt32},
  {"int64", TypeId::kNumberTypeInt64},     {"uint", TypeId::kNumberTypeUInt},
  {"uint8", TypeId::kNumberTypeUInt8},     {"uint16", TypeId::kNumberTypeUInt16},
  {"uint32", TypeId::kNumberTypeUInt32},   {"uint64", TypeId::kNumberTypeUInt64},
  {"bool", TypeId::kNumberTypeBool},
};
// Reverse mapping: TypeId back to its string form. Unlike type_id_maps above,
// kNumberTypeFloat and kNumberTypeFloat32 map to distinct strings here.
const std::map<TypeId, std::string> type_id_str_map = {
  {TypeId::kNumberTypeFloat32, "float32"}, {TypeId::kNumberTypeFloat16, "float16"},
  {TypeId::kNumberTypeFloat, "float"},     {TypeId::kNumberTypeFloat64, "float64"},
  {TypeId::kNumberTypeInt, "int"},         {TypeId::kNumberTypeInt8, "int8"},
  {TypeId::kNumberTypeInt16, "int16"},     {TypeId::kNumberTypeInt32, "int32"},
  {TypeId::kNumberTypeInt64, "int64"},     {TypeId::kNumberTypeUInt, "uint"},
  {TypeId::kNumberTypeUInt8, "uint8"},     {TypeId::kNumberTypeUInt16, "uint16"},
  {TypeId::kNumberTypeUInt32, "uint32"},   {TypeId::kNumberTypeUInt64, "uint64"},
  {TypeId::kNumberTypeBool, "bool"},
};
// Short-form dtype names, e.g. "float32" -> "f32" ("bool" has no short form).
const std::unordered_map<std::string, std::string> dtype_shortdtype_map_ = {
  {"float16", "f16"}, {"float32", "f32"}, {"float64", "f64"}, {"int8", "i8"}, {"int16", "i16"}, {"int32", "i32"},
  {"int64", "i64"},   {"uint8", "u8"},    {"uint16", "u16"},  {"uint32", "u32"}, {"uint64", "u64"}, {"bool", "bool"},
};
  59. const std::unordered_map<std::string, size_t> dtype_nbyte_map = {
  60. {"float16", sizeof(float) / 2}, {"float32", sizeof(float)}, {"float64", sizeof(float) * 2},
  61. {"int8", sizeof(int) / 4}, {"int16", sizeof(int) / 2}, {"int32", sizeof(int)},
  62. {"int64", sizeof(int) * 2}, {"uint8", sizeof(int) / 4}, {"uint16", sizeof(int) / 2},
  63. {"uint32", sizeof(int)}, {"uint64", sizeof(int) * 2}, {"bool", sizeof(char)},
  64. };
// Maps fusion-type strings from op registration to the FusionType enum.
// "CONVLUTION" (sic) matches the FusionType::CONVLUTION enumerator spelling,
// so the misspelling must be preserved here.
const std::unordered_map<std::string, FusionType> fusion_type_maps = {
  {"CONVLUTION", FusionType::CONVLUTION}, {"ELEMWISE", FusionType::ELEMWISE}, {"COMMREDUCE", FusionType::COMMREDUCE},
  {"SEGMENT", FusionType::SEGMENT}, {"OPAQUE", FusionType::OPAQUE},
};
  69. void KernelMeta::Initialize() {
  70. kernel_meta_path_ = std::string(kGpuKernelMeta) + "_" + std::to_string(getpid()) + "/";
  71. // remove old kernel cache
  72. RemoveKernelCache();
  73. #if defined(_WIN32) || defined(_WIN64)
  74. auto ret = mkdir(kernel_meta_path_.c_str());
  75. #else
  76. auto ret = mkdir(kernel_meta_path_.c_str(), S_IRWXG | S_IRWXU);
  77. #endif
  78. if (ret != 0) {
  79. MS_LOG(INFO) << "kernel dir [" << kernel_meta_path_ << "], will be created later";
  80. }
  81. initialized_ = true;
  82. }
  83. void KernelMeta::RemoveKernelCache() {
  84. DIR *dir = opendir(kernel_meta_path_.c_str());
  85. if (dir == nullptr) {
  86. return;
  87. }
  88. struct dirent *entry;
  89. while ((entry = readdir(dir)) != nullptr) {
  90. std::string kernel_file = entry->d_name;
  91. std::string kernel_file_realpath = kernel_meta_path_ + kernel_file;
  92. (void)remove(kernel_file_realpath.c_str());
  93. }
  94. (void)closedir(dir);
  95. (void)rmdir(kernel_meta_path_.c_str());
  96. }
  97. std::string KernelMeta::Search(const std::string &kernel_name) const {
  98. if (!initialized_) {
  99. return "";
  100. }
  101. auto iter = kernel_meta_map_.find(kernel_name);
  102. if (iter == kernel_meta_map_.end()) {
  103. return "";
  104. } else {
  105. return iter->second;
  106. }
  107. }
  108. bool KernelMeta::Insert(const std::string &kernel_name, const std::string &kernel_json) {
  109. if (!initialized_) {
  110. return false;
  111. }
  112. kernel_meta_map_[kernel_name] = kernel_json;
  113. return true;
  114. }
  115. bool CheckCache(const std::string &kernel_name) {
  116. // check cache.
  117. KernelMeta *bin_map = KernelMeta::GetInstance();
  118. if (bin_map == nullptr) {
  119. MS_LOG(DEBUG) << "kernel cache is invalid.";
  120. return false;
  121. }
  122. std::string kernel_json = bin_map->Search(kernel_name);
  123. bool ret = (!kernel_json.empty());
  124. if (ret) {
  125. MS_LOG(INFO) << "Kernel name:" << kernel_name << " has registed.";
  126. } else {
  127. MS_LOG(INFO) << "Kernel name:" << kernel_name << " will been registed.";
  128. }
  129. return ret;
  130. }
  131. KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor) {
  132. // search cache.
  133. KernelMeta *bin_map = KernelMeta::GetInstance();
  134. if (bin_map == nullptr) {
  135. MS_LOG(DEBUG) << "kernel cache is invalid.";
  136. return nullptr;
  137. }
  138. std::string kernel_json = bin_map->Search(kernel_name);
  139. if (!kernel_json.empty()) {
  140. KernelPackPtr kernel_pack = std::make_shared<KernelPack>();
  141. // just a tmp solution.
  142. if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) {
  143. MS_LOG(DEBUG) << "Read cache json and bin file failed[" << kernel_json << "].";
  144. return nullptr;
  145. } else {
  146. return kernel_pack;
  147. }
  148. } else {
  149. MS_LOG(INFO) << "cache kernel not found[" << kernel_name << "].";
  150. return nullptr;
  151. }
  152. }
  153. KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor) {
  154. MS_LOG(INFO) << "kernel name:" << kernel_name << ", processr:" << processor;
  155. KernelMeta *bin_map = KernelMeta::GetInstance();
  156. std::string kernel_json;
  157. if (processor == kProcessorAiCore || processor == kProcessorAiCpu) {
  158. kernel_json = kCceKernelMeta;
  159. } else {
  160. kernel_json = bin_map->GetKernelMetaPath();
  161. }
  162. (void)kernel_json.append(kernel_name).append(kJsonSuffix);
  163. KernelPackPtr kernel_pack = std::make_shared<KernelPack>();
  164. if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) {
  165. MS_LOG(DEBUG) << "Read json and bin file failed[" << kernel_json << "].";
  166. return nullptr;
  167. }
  168. if (bin_map == nullptr) {
  169. MS_LOG(DEBUG) << "kernel cache is invalid.";
  170. return nullptr;
  171. }
  172. if (bin_map->Insert(kernel_name, kernel_json)) {
  173. MS_LOG(INFO) << "Insert to cache success[" << kernel_json << "], kernelname[" << kernel_name << "].";
  174. }
  175. return kernel_pack;
  176. }
  177. TypeId DtypeToTypeId(const std::string &dtypes) {
  178. auto iter = type_id_maps.find(dtypes);
  179. if (iter != type_id_maps.end()) {
  180. return iter->second;
  181. } else {
  182. MS_EXCEPTION(ArgumentError) << "Illegal input device dtype:" << dtypes;
  183. }
  184. }
  185. std::string TypeId2String(TypeId type_id) {
  186. auto iter = type_id_str_map.find(type_id);
  187. if (iter == type_id_str_map.end()) {
  188. return std::string(TypeIdLabel(type_id));
  189. }
  190. return iter->second;
  191. }
  192. std::string Dtype2ShortType(const std::string &dtypes) {
  193. auto iter = dtype_shortdtype_map_.find(dtypes);
  194. if (iter != dtype_shortdtype_map_.end()) {
  195. return iter->second;
  196. } else {
  197. MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtypes;
  198. }
  199. }
  200. size_t GetDtypeNbyte(const std::string &dtypes) {
  201. auto iter = dtype_nbyte_map.find(dtypes);
  202. if (iter != dtype_nbyte_map.end()) {
  203. return iter->second;
  204. } else {
  205. MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtypes;
  206. }
  207. }
  208. bool SetInputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inputs, size_t real_input_num,
  209. size_t builder_idex, const std::vector<int> &dyn_input_sizes,
  210. const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
  211. MS_EXCEPTION_IF_NULL(builder);
  212. std::vector<TypeId> inputs_device_type;
  213. std::vector<std::string> inputs_format;
  214. size_t dyn_input_idx = 0;
  215. size_t kernel_info_index = 0;
  216. MS_EXCEPTION_IF_NULL(inputs[0]);
  217. size_t kernel_info_cnt = inputs[0]->dtypes().size();
  218. for (const auto &input : inputs) {
  219. MS_EXCEPTION_IF_NULL(input);
  220. std::string param_type = input->param_type();
  221. std::vector<std::string> dtypes = input->dtypes();
  222. std::vector<std::string> formats = input->formats();
  223. if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) {
  224. MS_LOG(DEBUG) << "Set input kernel builder info, dtyps size != formats size.";
  225. return false;
  226. }
  227. if (param_type == "dynamic") {
  228. if (dyn_input_sizes.empty()) {
  229. MS_LOG(DEBUG) << "Set input kernel builder info, dyn_input_sizes's size is 0 when param_type is dynamic";
  230. return false;
  231. }
  232. for (int t = 0; t < dyn_input_sizes[dyn_input_idx]; t++) {
  233. kernel_info_index++;
  234. auto type_id = DtypeToTypeId(dtypes[builder_idex]);
  235. inputs_device_type.push_back(type_id);
  236. inputs_format.push_back(formats[builder_idex]);
  237. }
  238. dyn_input_idx++;
  239. } else if (param_type == "required") {
  240. kernel_info_index++;
  241. auto type_id = DtypeToTypeId(dtypes[builder_idex]);
  242. inputs_device_type.push_back(type_id);
  243. inputs_format.push_back(formats[builder_idex]);
  244. } else {
  245. if (kernel_info_index < real_input_num) {
  246. MS_LOG(INFO) << "Set input kernel builder info, input type is optional, input index is :" << kernel_info_index;
  247. kernel_info_index++;
  248. auto type_id = DtypeToTypeId(dtypes[builder_idex]);
  249. inputs_device_type.push_back(type_id);
  250. inputs_format.push_back(formats[builder_idex]);
  251. }
  252. }
  253. }
  254. builder->SetInputsDeviceType(inputs_device_type);
  255. builder->SetInputsFormat(inputs_format);
  256. return true;
  257. }
  258. bool SetOutputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> &outputs, size_t builder_idex,
  259. const size_t &real_output_num,
  260. const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
  261. // not now but in the next we need to support dynamic output case
  262. MS_EXCEPTION_IF_NULL(builder);
  263. size_t output_idx = 0;
  264. std::vector<TypeId> outputs_device_type;
  265. std::vector<std::string> outputs_format;
  266. MS_EXCEPTION_IF_NULL(outputs[0]);
  267. size_t kernel_info_cnt = outputs[0]->dtypes().size();
  268. for (const auto &output : outputs) {
  269. MS_EXCEPTION_IF_NULL(output);
  270. if (output_idx >= real_output_num) {
  271. MS_LOG(DEBUG) << "real_output_num:" << real_output_num << ", output_idx:" << output_idx << " is out of limit!";
  272. continue;
  273. }
  274. size_t output_num = 0;
  275. if (output->param_type() == "dynamic") {
  276. if (outputs.size() > 1) {
  277. MS_EXCEPTION(ArgumentError) << "Dynamic output is unsupported multi output!";
  278. }
  279. output_num = real_output_num;
  280. } else if (output->param_type() == "required") {
  281. output_num = 1;
  282. } else {
  283. if (output_idx < real_output_num) {
  284. MS_LOG(DEBUG) << "Set output kernel builder info, output type is optional, output index is :" << output_idx;
  285. output_num = 1;
  286. }
  287. }
  288. for (size_t i = 0; i < output_num; i++) {
  289. std::vector<std::string> dtypes = output->dtypes();
  290. std::vector<std::string> formats = output->formats();
  291. if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) {
  292. MS_LOG(DEBUG) << "Set output kernel builder info, dtyps size != formats size.";
  293. return false;
  294. }
  295. auto type_id = DtypeToTypeId(dtypes[builder_idex]);
  296. outputs_device_type.push_back(type_id);
  297. outputs_format.push_back(formats[builder_idex]);
  298. output_idx++;
  299. }
  300. }
  301. builder->SetOutputsFormat(outputs_format);
  302. builder->SetOutputsDeviceType(outputs_device_type);
  303. return true;
  304. }
  305. void SetKernelBuildInfo(const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder, Processor processor,
  306. const std::shared_ptr<const OpInfo> &op_info_ptr) {
  307. MS_EXCEPTION_IF_NULL(builder);
  308. MS_EXCEPTION_IF_NULL(op_info_ptr);
  309. auto imply_type = op_info_ptr->imply_type();
  310. builder->SetProcessor(processor);
  311. std::string fusion_type = op_info_ptr->fusion_type();
  312. auto iter = fusion_type_maps.find(fusion_type);
  313. if (iter != fusion_type_maps.end()) {
  314. builder->SetFusionType(iter->second);
  315. } else {
  316. if (imply_type == kAKG) {
  317. MS_EXCEPTION(NotExistsError) << "Illegal fusion type from dsl register:" << fusion_type;
  318. }
  319. }
  320. if (imply_type == kAKG) {
  321. builder->SetKernelType(AKG_KERNEL);
  322. } else if (imply_type == kAICPU) {
  323. builder->SetKernelType(AICPU_KERNEL);
  324. } else {
  325. builder->SetKernelType(TBE_KERNEL);
  326. }
  327. }
// Build the candidate KernelBuildInfo list for kernel_node from its registered
// op information. One build info is produced per dtype/format combination of
// the first input (or first output when there are no inputs).
// Returns false when input/output builder info cannot be assembled.
bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr, Processor processor,
                   std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(kernel_info_list);
  size_t real_input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  size_t real_output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  std::vector<std::shared_ptr<OpIOInfo>> inputs = op_info_ptr->inputs_ptr();
  std::vector<std::shared_ptr<OpIOInfo>> outputs = op_info_ptr->outputs_ptr();
  // Per-dynamic-input element counts, taken from the primitive when present.
  std::vector<int> dyn_input_sizes;
  auto primitive = AnfAlgo::GetCNodePrimitive(kernel_node);
  MS_EXCEPTION_IF_NULL(primitive);
  if (primitive->GetAttr("dyn_input_sizes") != nullptr) {
    dyn_input_sizes = GetValue<std::vector<int>>(primitive->GetAttr("dyn_input_sizes"));
  }
  if (inputs.size() > 0) {
    // One builder per dtype/format combination registered on the first input.
    MS_EXCEPTION_IF_NULL(inputs[0]);
    size_t kernel_info_cnt = inputs[0]->dtypes().size();
    for (size_t j = 0; j < kernel_info_cnt; j++) {
      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(builder);
      SetKernelBuildInfo(builder, processor, op_info_ptr);
      if (!SetInputKernelBuilderInfo(inputs, real_input_num, j, dyn_input_sizes, builder)) {
        MS_LOG(DEBUG) << "Parse kernel metadata, set inputs kernel builder info failed.";
        return false;
      }
      if (outputs.size() > 0) {
        if (!SetOutputKernelBuilderInfo(outputs, j, real_output_num, builder)) {
          MS_LOG(DEBUG) << "Parse kernel metadata, set outputs kernel builder info failed.";
          return false;
        }
      }
      kernel_info_list->push_back(builder->Build());
    }
  } else if (outputs.size() > 0) {
    // No inputs: iterate over the output dtype/format combinations instead.
    MS_EXCEPTION_IF_NULL(outputs[0]);
    size_t kernel_info_cnt = outputs[0]->dtypes().size();
    for (size_t j = 0; j < kernel_info_cnt; j++) {
      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(builder);
      SetKernelBuildInfo(builder, processor, op_info_ptr);
      if (!SetOutputKernelBuilderInfo(outputs, j, real_output_num, builder)) {
        MS_LOG(DEBUG) << "Parse kernel metadata, set outputs kernel builder info failed.";
        return false;
      }
      kernel_info_list->push_back(builder->Build());
    }
  } else {
    // No inputs and no outputs: only AICPU ops still get a (bare) build info.
    if (processor == AICPU) {
      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(builder);
      SetKernelBuildInfo(builder, processor, op_info_ptr);
      kernel_info_list->push_back(builder->Build());
    }
  }
  return true;
}
  384. void SaveJsonInfo(const std::string &json_name, const std::string &info) {
  385. char real_path[PATH_MAX] = {0};
  386. std::string path = kCceKernelMeta + json_name + kInfoSuffix;
  387. if (path.size() > PATH_MAX) {
  388. MS_LOG(DEBUG) << "file path " << path << " is too long.";
  389. return;
  390. }
  391. std::ofstream filewrite;
  392. filewrite.open(path);
  393. if (!filewrite.is_open()) {
  394. return;
  395. }
  396. filewrite << info << std::endl;
  397. filewrite.close();
  398. #if defined(_WIN32) || defined(_WIN64)
  399. if (nullptr == _fullpath(real_path, path.c_str(), PATH_MAX)) {
  400. MS_LOG(DEBUG) << "dir " << path << " does not exit.";
  401. return;
  402. }
  403. #else
  404. if (nullptr == realpath(path.c_str(), real_path)) {
  405. MS_LOG(DEBUG) << "dir " << path << " does not exit.";
  406. return;
  407. }
  408. #endif
  409. MS_LOG(INFO) << "real path is :" << real_path;
  410. if (chmod(real_path, S_IRUSR) == -1) {
  411. MS_LOG(DEBUG) << "modify file:" << real_path << " to read only fail.";
  412. }
  413. }
  414. std::string GetProcessor(const AnfNodePtr &anf_node) {
  415. MS_EXCEPTION_IF_NULL(anf_node);
  416. std::string device;
  417. switch (AnfAlgo::GetProcessor(anf_node)) {
  418. case Processor::AICORE:
  419. device = kProcessorAiCore;
  420. break;
  421. case Processor::AICPU:
  422. device = kProcessorAiCpu;
  423. break;
  424. case Processor::CUDA:
  425. device = kProcessorCuda;
  426. break;
  427. default:
  428. MS_LOG(DEBUG) << "Unknown processor type.";
  429. break;
  430. }
  431. return device;
  432. }
  433. bool IsSameShape(const std::vector<size_t> &shape_a, const std::vector<size_t> &shape_b) {
  434. if (shape_a.size() != shape_b.size()) {
  435. return false;
  436. }
  437. for (size_t i = 0; i < shape_a.size(); ++i) {
  438. if (shape_a[i] != shape_b[i]) {
  439. return false;
  440. }
  441. }
  442. return true;
  443. }
  444. int Sign(float x) {
  445. if (x > 0) {
  446. return 1;
  447. }
  448. if (x < 0) {
  449. return -1;
  450. }
  451. return 0;
  452. }
// Accumulate rows of origin_sparse_grad that share the same index into
// unique_grad using a hash map. Rows with out-of-range indices (< 0 or
// >= first_dim) are dropped. outer_dim is the number of values per row.
// On return, unique_grad->indices_size_ holds the number of unique indices.
void DeduplicateIndexedSlices(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, size_t first_dim,
                              size_t outer_dim) {
  MS_EXCEPTION_IF_NULL(origin_sparse_grad.value_);
  MS_EXCEPTION_IF_NULL(origin_sparse_grad.indices_);
  MS_EXCEPTION_IF_NULL(unique_grad);
  MS_EXCEPTION_IF_NULL(unique_grad->value_);
  MS_EXCEPTION_IF_NULL(unique_grad->indices_);
  // Maps an original index value to its slot in unique_grad.
  std::unordered_map<int, size_t> index_map;
  size_t unique_indices_size = 0;
  for (size_t i = 0; i < origin_sparse_grad.indices_size_; ++i) {
    int index = origin_sparse_grad.indices_[i];
    if (index < 0 || IntToSize(index) >= first_dim) {
      continue;  // drop invalid rows entirely
    }
    auto iter = index_map.find(index);
    if (iter == index_map.end()) {
      // First occurrence: copy the row into the next free output slot.
      index_map[index] = unique_indices_size;
      unique_grad->indices_[unique_indices_size] = index;
      size_t start_index = unique_indices_size * outer_dim;
      size_t end_index = start_index + outer_dim;
      for (size_t j = start_index, k = i * outer_dim; j < end_index; ++j, ++k) {
        unique_grad->value_[j] = origin_sparse_grad.value_[k];
      }
      unique_indices_size++;
    } else {
      // Duplicate index: accumulate onto the previously copied row.
      size_t first_index = iter->second;
      size_t start_index = first_index * outer_dim;
      size_t end_index = start_index + outer_dim;
      for (size_t j = start_index, k = i * outer_dim; j < end_index; ++j, ++k) {
        unique_grad->value_[j] += origin_sparse_grad.value_[k];
      }
    }
  }
  unique_grad->indices_size_ = unique_indices_size;
}
// Parameter bundle passed (by value) to WorkerForReduceSparseGradient.
// Each worker handles the slice range [slice_start_, slice_end_) of
// *slice_positions_.
struct WorkerParamsForReduceSparseGradient {
  size_t slice_start_{0};  // first slice index handled by this worker (inclusive)
  size_t slice_end_{0};    // one past the last slice index (exclusive)
  size_t max_length_{0};   // float capacity bound used for memcpy_s destination sizing
  size_t outer_dim_{0};    // number of values per gradient row
  // (index, value-offset) pairs sorted by index.
  std::vector<std::pair<int, size_t>> *sorted_indices_{nullptr};
  // Start position of each run of equal indices within *sorted_indices_.
  std::vector<size_t> *slice_positions_{nullptr};
  float *src_value_{nullptr};             // source gradient values
  SparseGradient *unique_grad_{nullptr};  // output: deduplicated gradient
};
// Thread worker: for each slice in [slice_start_, slice_end_), copy the first
// row of that run into unique_grad_ and accumulate the remaining duplicate
// rows onto it. Logs EXCEPTION if memcpy_s reports an error.
void WorkerForReduceSparseGradient(WorkerParamsForReduceSparseGradient param) {
  MS_EXCEPTION_IF_NULL(param.sorted_indices_);
  MS_EXCEPTION_IF_NULL(param.slice_positions_);
  MS_EXCEPTION_IF_NULL(param.src_value_);
  MS_EXCEPTION_IF_NULL(param.unique_grad_);
  auto outer_dim = param.outer_dim_;
  auto &sorted_indices = *(param.sorted_indices_);
  auto &slice_positions = *(param.slice_positions_);
  auto unique_grad = param.unique_grad_;
  for (size_t slice_id = param.slice_start_; slice_id < param.slice_end_; ++slice_id) {
    size_t cur_pos = slice_positions[slice_id];
    int index = sorted_indices[cur_pos].first;
    unique_grad->indices_[slice_id] = index;
    size_t start_index = slice_id * outer_dim;
    // Copy the first row of the run; the destination bound shrinks as output
    // slots toward the end of the buffer are consumed.
    auto ret_code = memcpy_s(unique_grad->value_ + start_index, (param.max_length_ - start_index) * sizeof(float),
                             param.src_value_ + sorted_indices[cur_pos].second, outer_dim * sizeof(float));
    if (ret_code != EOK) {
      MS_LOG(EXCEPTION) << "Failed to copy data!";
    }
    cur_pos++;
    // The run ends at the next slice's start position, or at the data's end.
    size_t end_pos;
    if (slice_id + 1 < slice_positions.size()) {
      end_pos = slice_positions[slice_id + 1];
    } else {
      end_pos = sorted_indices.size();
    }
    // Accumulate every remaining row that shares this index.
    while (cur_pos < end_pos) {
      for (size_t i = 0; i < outer_dim; ++i) {
        unique_grad->value_[start_index + i] += param.src_value_[sorted_indices[cur_pos].second + i];
      }
      cur_pos++;
    }
  }
}
  532. void RunMultiThreadReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad,
  533. size_t outer_dim, std::vector<std::pair<int, size_t>> *sorted_indices,
  534. std::vector<size_t> *slice_positions) {
  535. MS_LOG(DEBUG) << "Start";
  536. size_t thread_num = 24;
  537. if (slice_positions->size() < thread_num) {
  538. thread_num = slice_positions->size();
  539. }
  540. size_t stride = (slice_positions->size() + thread_num - 1) / thread_num;
  541. thread_num = (slice_positions->size() + stride - 1) / stride;
  542. std::vector<std::thread> threads;
  543. size_t max_length = sorted_indices->size() * outer_dim;
  544. for (size_t i = 0; i < thread_num; ++i) {
  545. size_t slice_start = i * stride;
  546. size_t slice_end = 0;
  547. if (i == thread_num - 1) {
  548. slice_end = slice_positions->size();
  549. } else {
  550. slice_end = slice_start + stride;
  551. }
  552. WorkerParamsForReduceSparseGradient params{
  553. slice_start, slice_end, max_length, outer_dim, sorted_indices, slice_positions, origin_sparse_grad.value_,
  554. unique_grad};
  555. threads.emplace_back(std::thread(WorkerForReduceSparseGradient, params));
  556. }
  557. for (size_t i = 0; i < thread_num; ++i) {
  558. threads[i].join();
  559. }
  560. MS_LOG(DEBUG) << "End";
  561. }
// Deduplicate origin_sparse_grad into unique_grad by sorting the valid
// indices and merging runs of equal indices. first_dim bounds valid index
// values; outer_dim is the row width. The single-threaded path
// (use_multi_threads == false) is reused by TwoLevelReduceSparseGradient.
void ReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *unique_grad, size_t first_dim,
                          size_t outer_dim, bool use_multi_threads) {
  MS_LOG(DEBUG) << "Start";
  MS_EXCEPTION_IF_NULL(origin_sparse_grad.value_);
  MS_EXCEPTION_IF_NULL(origin_sparse_grad.indices_);
  MS_EXCEPTION_IF_NULL(unique_grad);
  MS_EXCEPTION_IF_NULL(unique_grad->value_);
  MS_EXCEPTION_IF_NULL(unique_grad->indices_);
  // Collect (index, value offset) pairs for in-range indices only.
  std::vector<std::pair<int, size_t>> sorted_indices;
  sorted_indices.reserve(origin_sparse_grad.indices_size_);
  for (size_t i = 0; i < origin_sparse_grad.indices_size_; ++i) {
    int index = origin_sparse_grad.indices_[i];
    if (index >= 0 && IntToSize(index) < first_dim) {
      sorted_indices.emplace_back(std::pair<int, size_t>(index, i * outer_dim));
    }
  }
  // Sort by index only; the paired value offset travels with its index.
  std::sort(
    sorted_indices.begin(), sorted_indices.end(),
    [](const std::pair<int, size_t> &left, const std::pair<int, size_t> &right) { return left.first < right.first; });
  // slice_positions marks where each run of equal indices begins.
  int last_index = 0;
  std::vector<size_t> slice_positions;
  slice_positions.reserve(sorted_indices.size());
  for (size_t i = 0; i < sorted_indices.size(); ++i) {
    if (i == 0 || last_index != sorted_indices[i].first) {
      slice_positions.emplace_back(i);
    }
    last_index = sorted_indices[i].first;
  }
  if (use_multi_threads) {
    RunMultiThreadReduceSparseGradient(origin_sparse_grad, unique_grad, outer_dim, &sorted_indices, &slice_positions);
  } else {
    size_t max_length = sorted_indices.size() * outer_dim;
    WorkerParamsForReduceSparseGradient params{0,
                                               slice_positions.size(),
                                               max_length,
                                               outer_dim,
                                               &sorted_indices,
                                               &slice_positions,
                                               origin_sparse_grad.value_,
                                               unique_grad};
    WorkerForReduceSparseGradient(params);
  }
  // One output row per unique index.
  unique_grad->indices_size_ = slice_positions.size();
  MS_LOG(DEBUG) << "End";
}
  607. void ReduceMultiSparseGradient(const std::vector<std::shared_ptr<SparseGradient>> &unique_slice_grads,
  608. SparseGradient *tmp_grad, SparseGradient *unique_grad, size_t first_dim,
  609. size_t outer_dim) {
  610. MS_LOG(DEBUG) << "Start";
  611. if (unique_slice_grads.empty()) {
  612. return;
  613. }
  614. size_t index_data_size = outer_dim * sizeof(float);
  615. size_t unique_indices_size = 0;
  616. for (size_t i = 0; i < unique_slice_grads.size(); ++i) {
  617. auto &slice_grad = unique_slice_grads[i];
  618. auto ret_code = memcpy_s(tmp_grad->value_ + unique_indices_size * outer_dim,
  619. (tmp_grad->indices_size_ - unique_indices_size) * index_data_size, slice_grad->value_,
  620. slice_grad->indices_size_ * index_data_size);
  621. if (ret_code != EOK) {
  622. MS_LOG(EXCEPTION) << "Failed to copy data!";
  623. }
  624. ret_code =
  625. memcpy_s(tmp_grad->indices_ + unique_indices_size, (tmp_grad->indices_size_ - unique_indices_size) * sizeof(int),
  626. slice_grad->indices_, slice_grad->indices_size_ * sizeof(int));
  627. if (ret_code != EOK) {
  628. MS_LOG(EXCEPTION) << "Failed to copy data!";
  629. }
  630. unique_indices_size += slice_grad->indices_size_;
  631. }
  632. tmp_grad->indices_size_ = unique_indices_size;
  633. ReduceSparseGradient(*tmp_grad, unique_grad, first_dim, outer_dim);
  634. MS_LOG(DEBUG) << "End";
  635. }
  636. void TwoLevelReduceSparseGradient(const SparseGradient &origin_sparse_grad, SparseGradient *tmp_grad,
  637. SparseGradient *unique_grad, size_t first_dim, size_t outer_dim) {
  638. MS_LOG(DEBUG) << "Start";
  639. MS_EXCEPTION_IF_NULL(origin_sparse_grad.value_);
  640. MS_EXCEPTION_IF_NULL(origin_sparse_grad.indices_);
  641. MS_EXCEPTION_IF_NULL(unique_grad);
  642. MS_EXCEPTION_IF_NULL(unique_grad->value_);
  643. MS_EXCEPTION_IF_NULL(unique_grad->indices_);
  644. MS_EXCEPTION_IF_NULL(tmp_grad);
  645. MS_EXCEPTION_IF_NULL(tmp_grad->value_);
  646. MS_EXCEPTION_IF_NULL(tmp_grad->indices_);
  647. size_t thread_num = 24;
  648. if (origin_sparse_grad.indices_size_ < thread_num) {
  649. thread_num = origin_sparse_grad.indices_size_;
  650. }
  651. size_t thread_indices_size = origin_sparse_grad.indices_size_ / thread_num;
  652. size_t left_indices_size = origin_sparse_grad.indices_size_ % thread_num;
  653. std::vector<std::thread> threads;
  654. threads.reserve(thread_num);
  655. std::vector<std::shared_ptr<SparseGradient>> unique_slice_grads;
  656. for (size_t i = 0; i < thread_num; ++i) {
  657. size_t indices_size = thread_indices_size;
  658. if (i == thread_num - 1) {
  659. indices_size = thread_indices_size + left_indices_size;
  660. }
  661. size_t value_offset = i * thread_indices_size * outer_dim;
  662. size_t indices_offset = i * thread_indices_size;
  663. auto slice_grad = SparseGradient(
  664. {origin_sparse_grad.value_ + value_offset, origin_sparse_grad.indices_ + indices_offset, indices_size});
  665. unique_slice_grads.emplace_back(std::make_shared<SparseGradient>());
  666. unique_slice_grads[i]->value_ = unique_grad->value_ + value_offset;
  667. unique_slice_grads[i]->indices_ = unique_grad->indices_ + indices_offset;
  668. unique_slice_grads[i]->indices_size_ = indices_size;
  669. threads.emplace_back(
  670. std::thread(ReduceSparseGradient, slice_grad, unique_slice_grads[i].get(), first_dim, outer_dim, false));
  671. }
  672. for (size_t i = 0; i < thread_num; ++i) {
  673. threads[i].join();
  674. }
  675. ReduceMultiSparseGradient(unique_slice_grads, tmp_grad, unique_grad, first_dim, outer_dim);
  676. MS_LOG(DEBUG) << "End";
  677. }
  678. std::pair<AnfNodePtr, size_t> GetKernelInput(const AnfNodePtr &anf_node, size_t index) {
  679. MS_EXCEPTION_IF_NULL(anf_node);
  680. if (index >= AnfAlgo::GetInputTensorNum(anf_node)) {
  681. MS_EXCEPTION(ArgumentError) << "Index is out of the size of anf_node inputs.";
  682. }
  683. auto cnode = anf_node->cast<CNodePtr>();
  684. if (cnode == nullptr) {
  685. return AnfAlgo::VisitKernel(anf_node, 0);
  686. } else {
  687. return AnfAlgo::VisitKernel(anf_node->cast<CNodePtr>()->input(index + 1), 0);
  688. }
  689. }
// For every graph input, finds the kernel node in `node_list` that consumes it
// and records the pair (kernel, (input slot, offset within a dynamic input)).
// Throws ArgumentError when an input has no users at all, or when none of its
// users appears in `node_list`.
std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> GetInputIndex(const std::vector<AnfNodePtr> &node_list,
                                                                            const std::vector<AnfNodePtr> &input_list) {
  std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> input_index;
  for (size_t i = 0; i < input_list.size(); ++i) {
    auto const &input = input_list[i];
    MS_EXCEPTION_IF_NULL(input);
    bool found = false;
    // using NodeUsersMap = std::unordered_map<AnfNodePtr, std::set<std::pair<AnfNodePtr, int>>>;
    auto mng = input->func_graph()->manager();
    MS_EXCEPTION_IF_NULL(mng);
    const NodeUsersMap &users = mng->node_users();
    auto input_users = users.find(input);
    if (input_users == users.end() || input_users->second.empty()) {
      MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of ["
                                  << input->func_graph()->ToString() << "] has no users.";
    }
    // Scan every (user node, use position) pair; stop at the first user that is
    // one of the kernels in node_list.
    for (auto const &input_user : input_users->second) {
      for (auto const &anf_node : node_list) {
        if (anf_node != input_user.first) {
          continue;
        }
        std::vector<int> dyn_input_sizes;
        auto prim = AnfAlgo::GetCNodePrimitive(anf_node);
        MS_EXCEPTION_IF_NULL(prim);
        if (prim->GetAttr(kAttrDynInputSizes) != nullptr) {
          dyn_input_sizes = GetValue<const std::vector<int>>(prim->GetAttr(kAttrDynInputSizes));
        }
        if (dyn_input_sizes.empty()) {
          // Plain input: slot is the cnode position minus the primitive slot;
          // no dynamic-input offset.
          input_index.push_back(std::make_pair(anf_node, std::make_pair(IntToSize(input_user.second - 1), 0)));
          found = true;
          break;
        } else {
          // Dynamic inputs: translate the flat use position into
          // (dynamic-input group index, offset inside that group).
          int used_as_idx = input_user.second - 1;
          int accum_idx = 0;
          size_t dyn_i = 0;
          for (; dyn_i < dyn_input_sizes.size(); ++dyn_i) {
            accum_idx += dyn_input_sizes[dyn_i];
            if (used_as_idx < accum_idx) {
              input_index.push_back(std::make_pair(
                anf_node, std::make_pair(dyn_i, IntToSize(used_as_idx - (accum_idx - dyn_input_sizes[dyn_i])))));
              break;
            }
          }
          // dyn_i != size means the inner loop broke, i.e. a group was matched.
          if (dyn_i != dyn_input_sizes.size()) {
            found = true;
            break;
          }
        }
      }
      if (found) {
        break;
      }
    }
    if (!found) {
      MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of ["
                                  << input->func_graph()->ToString() << "] found no related kernel info.";
    }
  }
  return input_index;
}
  750. std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list,
  751. const std::vector<AnfNodePtr> &input_list,
  752. const std::vector<AnfNodePtr> &output_list) {
  753. std::vector<std::pair<AnfNodePtr, size_t>> output_index;
  754. for (size_t i = 0; i < output_list.size(); ++i) {
  755. auto const &output = output_list[i];
  756. MS_EXCEPTION_IF_NULL(output);
  757. bool found = false;
  758. auto pree_node = AnfAlgo::VisitKernel(output, 0);
  759. auto pos = std::find(std::begin(node_list), std::end(node_list), pree_node.first);
  760. if (pos != std::end(node_list)) {
  761. output_index.push_back(pree_node);
  762. continue;
  763. }
  764. auto ret = std::find(std::begin(input_list), std::end(input_list), pree_node.first);
  765. if (ret != std::end(input_list)) {
  766. output_index.push_back(std::make_pair(pree_node.first, 0));
  767. found = true;
  768. }
  769. if (!found) {
  770. MS_EXCEPTION(ArgumentError) << "Output [" << i << "][" << output->DebugString(2) << "] of ["
  771. << output->func_graph()->ToString() << "] found no related kernel info.";
  772. }
  773. }
  774. return output_index;
  775. }
  776. void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list) {
  777. MS_EXCEPTION_IF_NULL(node_list);
  778. MS_EXCEPTION_IF_NULL(func_graph);
  779. std::vector<AnfNodePtr> node_lists = TopoSort(func_graph->get_return());
  780. for (auto const &node : node_lists) {
  781. if (!AnfAlgo::IsRealKernel(node) || !node->isa<CNode>()) {
  782. continue;
  783. }
  784. auto cnode = node->cast<CNodePtr>();
  785. MS_EXCEPTION_IF_NULL(cnode);
  786. if (IsValueNode<Primitive>(cnode->input(kAnfPrimitiveIndex))) {
  787. node_list->push_back(node);
  788. }
  789. }
  790. }
  791. void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list,
  792. std::vector<AnfNodePtr> *input_list, std::vector<AnfNodePtr> *output_list) {
  793. MS_EXCEPTION_IF_NULL(node_list);
  794. MS_EXCEPTION_IF_NULL(input_list);
  795. MS_EXCEPTION_IF_NULL(output_list);
  796. MS_EXCEPTION_IF_NULL(func_graph);
  797. GetValidKernelNodes(func_graph, node_list);
  798. auto parameters = func_graph->parameters();
  799. input_list->insert(input_list->begin(), parameters.begin(), parameters.end());
  800. auto func_output = func_graph->output();
  801. MS_EXCEPTION_IF_NULL(func_output);
  802. if (func_output->isa<CNode>()) {
  803. // multi output.
  804. auto cnode = func_output->cast<CNodePtr>();
  805. MS_EXCEPTION_IF_NULL(cnode);
  806. auto input0 = cnode->input(kAnfPrimitiveIndex);
  807. MS_EXCEPTION_IF_NULL(input0);
  808. if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
  809. for (size_t input_idx = 1; input_idx < cnode->inputs().size(); ++input_idx) {
  810. auto input_node = cnode->input(input_idx);
  811. MS_EXCEPTION_IF_NULL(input_node);
  812. output_list->push_back(AnfAlgo::VisitKernel(input_node, 0).first);
  813. }
  814. } else {
  815. // single output.
  816. output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first);
  817. }
  818. } else {
  819. // single output.
  820. output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first);
  821. }
  822. }
  823. bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *const node_json) {
  824. MS_EXCEPTION_IF_NULL(anf_node);
  825. MS_EXCEPTION_IF_NULL(node_json);
  826. auto cnode = anf_node->cast<CNodePtr>();
  827. MS_EXCEPTION_IF_NULL(cnode);
  828. if (input_idx + 1 >= cnode->size()) {
  829. MS_EXCEPTION(ArgumentError) << "input_idx [" << input_idx << "] is out of index of inputs of ["
  830. << cnode->inputs().size() << "][" << cnode->DebugString() << "]";
  831. }
  832. auto input_node = cnode->input(input_idx + 1);
  833. if (!IsValueNode<tensor::Tensor>(input_node)) {
  834. return false;
  835. }
  836. auto tensor = GetValueNode<tensor::TensorPtr>(input_node);
  837. if (tensor == nullptr) {
  838. return false;
  839. }
  840. auto type_id = tensor->data_type();
  841. auto *data = tensor->data_c();
  842. MS_EXCEPTION_IF_NULL(data);
  843. if (tensor->DataDim() > 1 || tensor->DataSize() != 1) {
  844. // not const tensor.
  845. MS_LOG(WARNING) << "We take first value of tensor whose datasize != 1, [" << input_node->DebugString(2) << "]";
  846. }
  847. if (type_id == kFloat32->type_id()) {
  848. float *val = static_cast<float *>(data);
  849. MS_EXCEPTION_IF_NULL(val);
  850. (*node_json)["value"] = val[0];
  851. MS_LOG(DEBUG) << "Value of tensor[" << cnode->DebugString() << "] is [float32][" << *val << "].";
  852. return true;
  853. } else if (type_id == kFloat16->type_id()) {
  854. float16 *val = static_cast<float16 *>(data);
  855. MS_EXCEPTION_IF_NULL(val);
  856. (*node_json)["value"] = static_cast<float>(val[0]);
  857. MS_LOG(INFO) << "Value of tensor[" << cnode->DebugString() << "] is [float16][" << *val << "].";
  858. return true;
  859. } else if (type_id == kInt32->type_id()) {
  860. int *val = static_cast<int *>(data);
  861. MS_EXCEPTION_IF_NULL(val);
  862. (*node_json)["value"] = val[0];
  863. MS_LOG(INFO) << "Value of tensor[" << cnode->DebugString() << "] is [int32][" << *val << "].";
  864. return true;
  865. }
  866. MS_LOG(ERROR) << "Unknown value type of tensor[" << cnode->DebugString() << "]";
  867. return false;
  868. }
  869. void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector<std::pair<AnfNodePtr, size_t>> *node_list) {
  870. MS_EXCEPTION_IF_NULL(func_graph);
  871. MS_EXCEPTION_IF_NULL(node_list);
  872. auto output = func_graph->output();
  873. MS_EXCEPTION_IF_NULL(output);
  874. if (AnfAlgo::IsRealKernel(output)) {
  875. // single output.
  876. node_list->push_back(std::make_pair(output, 0));
  877. return;
  878. } else if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
  879. auto output_cnode = output->cast<CNodePtr>();
  880. MS_EXCEPTION_IF_NULL(output_cnode);
  881. // multi output.
  882. auto &inputs = output_cnode->inputs();
  883. for (size_t i = 1; i < inputs.size(); ++i) {
  884. auto in_with_idx = AnfAlgo::VisitKernel(inputs[i], 0);
  885. node_list->push_back(in_with_idx);
  886. }
  887. return;
  888. }
  889. MS_EXCEPTION(ArgumentError) << "Unknown output type: " << output->DebugString(2)
  890. << " of graph: " << func_graph->ToString();
  891. }
  892. bool IsWeightBoundary(const AnfNodePtr &node) {
  893. if (node->isa<ValueNode>()) {
  894. return true;
  895. }
  896. if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) {
  897. return true;
  898. }
  899. return false;
  900. }
  901. void MultiThreadCompute(const MultiThreadComputeFunc &func, MultiThreadComputeParams *params,
  902. size_t total_compute_size) {
  903. const size_t kThreadNum = 24;
  904. std::vector<std::thread> threads;
  905. threads.reserve(kThreadNum);
  906. size_t start = 0;
  907. size_t once_compute_size = (total_compute_size + kThreadNum - 1) / kThreadNum;
  908. while (start < total_compute_size) {
  909. size_t end = (start + once_compute_size) > total_compute_size ? total_compute_size : (start + once_compute_size);
  910. threads.emplace_back(std::thread(func, params, start, end));
  911. start += once_compute_size;
  912. }
  913. for (size_t i = 0; i < threads.size(); ++i) {
  914. threads[i].join();
  915. }
  916. }
  917. std::vector<int> GetReduceAttrAxis(const CNodePtr &cnode) {
  918. if (AnfAlgo::GetInputTensorNum(cnode) != AnfAlgo::GetOutputTensorNum(cnode) &&
  919. AnfAlgo::GetInputTensorNum(cnode) != 1) {
  920. MS_LOG(EXCEPTION) << "the kind of reduce node [" << cnode->DebugString()
  921. << "] is not single input or single output ";
  922. }
  923. std::vector<int> axis;
  924. auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
  925. auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
  926. MS_EXCEPTION_IF_NULL(primitive);
  927. auto axis_attr = primitive->GetAttr(kAxis);
  928. if (axis_attr == nullptr) {
  929. MS_LOG(ERROR) << "This node does't have axie attr.";
  930. return std::vector<int>();
  931. }
  932. auto type = axis_attr->type();
  933. MS_EXCEPTION_IF_NULL(type);
  934. std::vector<int> axis_list;
  935. if (type->ToString() == kTypeInt32) {
  936. axis_list.emplace_back(GetValue<int>(axis_attr));
  937. } else {
  938. axis_list = GetValue<std::vector<int>>(axis_attr);
  939. }
  940. for (const auto &elem : axis_list) {
  941. if (elem < 0) {
  942. axis.emplace_back(input_shape.size() + elem);
  943. } else {
  944. axis.emplace_back(elem);
  945. }
  946. }
  947. AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis), cnode);
  948. return axis;
  949. }
  950. } // namespace kernel
  951. } // namespace mindspore