You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

common_utils.cc 44 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "backend/kernel_compiler/common_utils.h"
#include <algorithm>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <map>
#include <thread>
#include <unordered_map>
#include <utility>
#include "nlohmann/json.hpp"
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/ms_utils.h"
#include "ir/manager.h"
#include "ir/meta_tensor.h"
#include "ir/func_graph.h"
#include "frontend/operator/ops.h"
#include "ir/graph_utils.h"
  32. namespace mindspore {
  33. namespace kernel {
  34. constexpr char kAxis[] = "axis";
  35. constexpr char kTypeInt32[] = "Int32";
// Maps dtype strings from op-info registration to MindSpore TypeIds.
// Note: plain "float" aliases float32.
const std::unordered_map<std::string, TypeId> type_id_maps = {
  {"float", TypeId::kNumberTypeFloat32}, {"float16", TypeId::kNumberTypeFloat16},
  {"float32", TypeId::kNumberTypeFloat32}, {"float64", TypeId::kNumberTypeFloat64},
  {"int", TypeId::kNumberTypeInt}, {"int8", TypeId::kNumberTypeInt8},
  {"int16", TypeId::kNumberTypeInt16}, {"int32", TypeId::kNumberTypeInt32},
  {"int64", TypeId::kNumberTypeInt64}, {"uint", TypeId::kNumberTypeUInt},
  {"uint8", TypeId::kNumberTypeUInt8}, {"uint16", TypeId::kNumberTypeUInt16},
  {"uint32", TypeId::kNumberTypeUInt32}, {"uint64", TypeId::kNumberTypeUInt64},
  {"bool", TypeId::kNumberTypeBool},
};
// Reverse mapping: TypeId back to its canonical dtype string.
// (Unlike type_id_maps above, kNumberTypeFloat maps to "float", not "float32".)
const std::map<TypeId, std::string> type_id_str_map = {
  {TypeId::kNumberTypeFloat32, "float32"}, {TypeId::kNumberTypeFloat16, "float16"},
  {TypeId::kNumberTypeFloat, "float"}, {TypeId::kNumberTypeFloat64, "float64"},
  {TypeId::kNumberTypeInt, "int"}, {TypeId::kNumberTypeInt8, "int8"},
  {TypeId::kNumberTypeInt16, "int16"}, {TypeId::kNumberTypeInt32, "int32"},
  {TypeId::kNumberTypeInt64, "int64"}, {TypeId::kNumberTypeUInt, "uint"},
  {TypeId::kNumberTypeUInt8, "uint8"}, {TypeId::kNumberTypeUInt16, "uint16"},
  {TypeId::kNumberTypeUInt32, "uint32"}, {TypeId::kNumberTypeUInt64, "uint64"},
  {TypeId::kNumberTypeBool, "bool"},
};
// Maps dtype strings to the short forms used when composing kernel names (e.g. "f16").
const std::unordered_map<std::string, std::string> dtype_shortdtype_map_ = {
  {"float16", "f16"}, {"float32", "f32"}, {"float64", "f64"}, {"int8", "i8"}, {"int16", "i16"}, {"int32", "i32"},
  {"int64", "i64"}, {"uint8", "u8"}, {"uint16", "u16"}, {"uint32", "u32"}, {"uint64", "u64"}, {"bool", "bool"},
};
  60. const std::unordered_map<std::string, size_t> dtype_nbyte_map = {
  61. {"float16", sizeof(float) / 2}, {"float32", sizeof(float)}, {"float64", sizeof(float) * 2},
  62. {"int8", sizeof(int) / 4}, {"int16", sizeof(int) / 2}, {"int32", sizeof(int)},
  63. {"int64", sizeof(int) * 2}, {"uint8", sizeof(int) / 4}, {"uint16", sizeof(int) / 2},
  64. {"uint32", sizeof(int)}, {"uint64", sizeof(int) * 2}, {"bool", sizeof(char)},
  65. };
// Maps fusion-type strings from op registration to FusionType enum values.
// "CONVLUTION" (sic) intentionally matches the FusionType::CONVLUTION spelling;
// changing either side alone would break the lookup.
const std::unordered_map<std::string, FusionType> fusion_type_maps = {
  {"CONVLUTION", FusionType::CONVLUTION}, {"ELEMWISE", FusionType::ELEMWISE}, {"COMMREDUCE", FusionType::COMMREDUCE},
  {"SEGMENT", FusionType::SEGMENT}, {"OPAQUE", FusionType::OPAQUE},
};
  70. void KernelMeta::Initialize(int pid) {
  71. if (pid == -1) {
  72. kernel_meta_path_ = std::string(kGpuKernelMeta) + "_" + std::to_string(getpid()) + "/";
  73. } else {
  74. kernel_meta_path_ = std::string(kGpuKernelMeta) + "_" + std::to_string(pid) + "/";
  75. }
  76. // remove old kernel cache
  77. RemoveKernelCache();
  78. #if defined(_WIN32) || defined(_WIN64)
  79. auto ret = mkdir(kernel_meta_path_.c_str());
  80. #else
  81. auto ret = mkdir(kernel_meta_path_.c_str(), S_IRWXG | S_IRWXU);
  82. #endif
  83. if (ret != 0) {
  84. MS_LOG(INFO) << "kernel dir [" << kernel_meta_path_ << "], will be created later";
  85. }
  86. initialized_ = true;
  87. }
  88. void KernelMeta::RemoveKernelCache() {
  89. DIR *dir = opendir(kernel_meta_path_.c_str());
  90. if (dir == nullptr) {
  91. return;
  92. }
  93. struct dirent *entry;
  94. while ((entry = readdir(dir)) != nullptr) {
  95. std::string kernel_file = entry->d_name;
  96. std::string kernel_file_realpath = kernel_meta_path_ + kernel_file;
  97. (void)remove(kernel_file_realpath.c_str());
  98. }
  99. (void)closedir(dir);
  100. (void)rmdir(kernel_meta_path_.c_str());
  101. }
  102. std::string KernelMeta::Search(const std::string &kernel_name) const {
  103. if (!initialized_) {
  104. return "";
  105. }
  106. auto iter = kernel_meta_map_.find(kernel_name);
  107. if (iter == kernel_meta_map_.end()) {
  108. return "";
  109. } else {
  110. return iter->second;
  111. }
  112. }
  113. bool KernelMeta::Insert(const std::string &kernel_name, const std::string &kernel_json) {
  114. if (!initialized_) {
  115. return false;
  116. }
  117. kernel_meta_map_[kernel_name] = kernel_json;
  118. return true;
  119. }
  120. bool CheckCache(const std::string &kernel_name) {
  121. // check cache.
  122. KernelMeta *bin_map = KernelMeta::GetInstance();
  123. if (bin_map == nullptr) {
  124. MS_LOG(DEBUG) << "kernel cache is invalid.";
  125. return false;
  126. }
  127. std::string kernel_json = bin_map->Search(kernel_name);
  128. bool ret = (!kernel_json.empty());
  129. if (ret) {
  130. MS_LOG(INFO) << "Kernel name:" << kernel_name << " has registed.";
  131. } else {
  132. MS_LOG(INFO) << "Kernel name:" << kernel_name << " will been registed.";
  133. }
  134. return ret;
  135. }
  136. KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor) {
  137. // search cache.
  138. KernelMeta *bin_map = KernelMeta::GetInstance();
  139. if (bin_map == nullptr) {
  140. MS_LOG(DEBUG) << "kernel cache is invalid.";
  141. return nullptr;
  142. }
  143. std::string kernel_json = bin_map->Search(kernel_name);
  144. if (!kernel_json.empty()) {
  145. KernelPackPtr kernel_pack = std::make_shared<KernelPack>();
  146. // just a tmp solution.
  147. if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) {
  148. MS_LOG(DEBUG) << "Read cache json and bin file failed[" << kernel_json << "].";
  149. return nullptr;
  150. } else {
  151. return kernel_pack;
  152. }
  153. } else {
  154. MS_LOG(INFO) << "cache kernel not found[" << kernel_name << "].";
  155. return nullptr;
  156. }
  157. }
  158. KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor) {
  159. MS_LOG(INFO) << "kernel name:" << kernel_name << ", processr:" << processor;
  160. KernelMeta *bin_map = KernelMeta::GetInstance();
  161. std::string kernel_json;
  162. if (processor == kProcessorAiCore || processor == kProcessorAiCpu) {
  163. kernel_json = kCceKernelMeta;
  164. } else {
  165. kernel_json = bin_map->GetKernelMetaPath();
  166. }
  167. (void)kernel_json.append(kernel_name).append(kJsonSuffix);
  168. KernelPackPtr kernel_pack = std::make_shared<KernelPack>();
  169. if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) {
  170. MS_LOG(DEBUG) << "Read json and bin file failed[" << kernel_json << "].";
  171. return nullptr;
  172. }
  173. if (bin_map == nullptr) {
  174. MS_LOG(DEBUG) << "kernel cache is invalid.";
  175. return nullptr;
  176. }
  177. if (bin_map->Insert(kernel_name, kernel_json)) {
  178. MS_LOG(INFO) << "Insert to cache success[" << kernel_json << "], kernelname[" << kernel_name << "].";
  179. }
  180. return kernel_pack;
  181. }
  182. TypeId DtypeToTypeId(const std::string &dtypes) {
  183. auto iter = type_id_maps.find(dtypes);
  184. if (iter != type_id_maps.end()) {
  185. return iter->second;
  186. } else {
  187. MS_EXCEPTION(ArgumentError) << "Illegal input device dtype:" << dtypes;
  188. }
  189. }
  190. std::string TypeId2String(TypeId type_id) {
  191. auto iter = type_id_str_map.find(type_id);
  192. if (iter == type_id_str_map.end()) {
  193. return std::string(TypeIdLabel(type_id));
  194. }
  195. return iter->second;
  196. }
  197. std::string Dtype2ShortType(const std::string &dtypes) {
  198. auto iter = dtype_shortdtype_map_.find(dtypes);
  199. if (iter != dtype_shortdtype_map_.end()) {
  200. return iter->second;
  201. } else {
  202. MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtypes;
  203. }
  204. }
  205. size_t GetDtypeNbyte(const std::string &dtypes) {
  206. auto iter = dtype_nbyte_map.find(dtypes);
  207. if (iter != dtype_nbyte_map.end()) {
  208. return iter->second;
  209. } else {
  210. MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtypes;
  211. }
  212. }
  213. bool SetInputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inputs, size_t real_input_num,
  214. size_t builder_idex, const std::vector<int> &dyn_input_sizes,
  215. const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
  216. MS_EXCEPTION_IF_NULL(builder);
  217. std::vector<TypeId> inputs_device_type;
  218. std::vector<std::string> inputs_format;
  219. size_t dyn_input_idx = 0;
  220. size_t kernel_info_index = 0;
  221. MS_EXCEPTION_IF_NULL(inputs[0]);
  222. size_t kernel_info_cnt = inputs[0]->dtypes().size();
  223. for (const auto &input : inputs) {
  224. MS_EXCEPTION_IF_NULL(input);
  225. std::string param_type = input->param_type();
  226. std::vector<std::string> dtypes = input->dtypes();
  227. std::vector<std::string> formats = input->formats();
  228. if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) {
  229. MS_LOG(DEBUG) << "Set input kernel builder info, dtyps size != formats size.";
  230. return false;
  231. }
  232. if (param_type == "dynamic") {
  233. if (dyn_input_sizes.empty()) {
  234. MS_LOG(DEBUG) << "Set input kernel builder info, dyn_input_sizes's size is 0 when param_type is dynamic";
  235. return false;
  236. }
  237. for (int t = 0; t < dyn_input_sizes[dyn_input_idx]; t++) {
  238. kernel_info_index++;
  239. auto type_id = DtypeToTypeId(dtypes[builder_idex]);
  240. inputs_device_type.push_back(type_id);
  241. inputs_format.push_back(formats[builder_idex]);
  242. }
  243. dyn_input_idx++;
  244. } else if (param_type == "required") {
  245. kernel_info_index++;
  246. auto type_id = DtypeToTypeId(dtypes[builder_idex]);
  247. inputs_device_type.push_back(type_id);
  248. inputs_format.push_back(formats[builder_idex]);
  249. } else {
  250. if (kernel_info_index < real_input_num) {
  251. MS_LOG(INFO) << "Set input kernel builder info, input type is optional, input index is :" << kernel_info_index;
  252. kernel_info_index++;
  253. auto type_id = DtypeToTypeId(dtypes[builder_idex]);
  254. inputs_device_type.push_back(type_id);
  255. inputs_format.push_back(formats[builder_idex]);
  256. }
  257. }
  258. }
  259. builder->SetInputsDeviceType(inputs_device_type);
  260. builder->SetInputsFormat(inputs_format);
  261. return true;
  262. }
  263. bool SetOutputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> &outputs, size_t builder_idex,
  264. const size_t &real_output_num,
  265. const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
  266. // not now but in the next we need to support dynamic output case
  267. MS_EXCEPTION_IF_NULL(builder);
  268. size_t output_idx = 0;
  269. std::vector<TypeId> outputs_device_type;
  270. std::vector<std::string> outputs_format;
  271. MS_EXCEPTION_IF_NULL(outputs[0]);
  272. size_t kernel_info_cnt = outputs[0]->dtypes().size();
  273. for (const auto &output : outputs) {
  274. MS_EXCEPTION_IF_NULL(output);
  275. if (output_idx >= real_output_num) {
  276. MS_LOG(DEBUG) << "real_output_num:" << real_output_num << ", output_idx:" << output_idx << " is out of limit!";
  277. continue;
  278. }
  279. size_t output_num = 0;
  280. if (output->param_type() == "dynamic") {
  281. if (outputs.size() > 1) {
  282. MS_EXCEPTION(ArgumentError) << "Dynamic output is unsupported multi output!";
  283. }
  284. output_num = real_output_num;
  285. } else if (output->param_type() == "required") {
  286. output_num = 1;
  287. } else {
  288. if (output_idx < real_output_num) {
  289. MS_LOG(DEBUG) << "Set output kernel builder info, output type is optional, output index is :" << output_idx;
  290. output_num = 1;
  291. }
  292. }
  293. for (size_t i = 0; i < output_num; i++) {
  294. std::vector<std::string> dtypes = output->dtypes();
  295. std::vector<std::string> formats = output->formats();
  296. if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) {
  297. MS_LOG(DEBUG) << "Set output kernel builder info, dtyps size != formats size.";
  298. return false;
  299. }
  300. auto type_id = DtypeToTypeId(dtypes[builder_idex]);
  301. outputs_device_type.push_back(type_id);
  302. outputs_format.push_back(formats[builder_idex]);
  303. output_idx++;
  304. }
  305. }
  306. builder->SetOutputsFormat(outputs_format);
  307. builder->SetOutputsDeviceType(outputs_device_type);
  308. return true;
  309. }
  310. void SetKernelBuildInfo(const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder, Processor processor,
  311. const std::shared_ptr<const OpInfo> &op_info_ptr) {
  312. MS_EXCEPTION_IF_NULL(builder);
  313. MS_EXCEPTION_IF_NULL(op_info_ptr);
  314. auto imply_type = op_info_ptr->imply_type();
  315. builder->SetProcessor(processor);
  316. std::string fusion_type = op_info_ptr->fusion_type();
  317. auto iter = fusion_type_maps.find(fusion_type);
  318. if (iter != fusion_type_maps.end()) {
  319. builder->SetFusionType(iter->second);
  320. } else {
  321. if (imply_type == kAKG) {
  322. MS_EXCEPTION(NotExistsError) << "Illegal fusion type from dsl register:" << fusion_type;
  323. }
  324. }
  325. if (imply_type == kAKG) {
  326. builder->SetKernelType(AKG_KERNEL);
  327. } else if (imply_type == kAICPU) {
  328. builder->SetKernelType(AICPU_KERNEL);
  329. } else {
  330. builder->SetKernelType(TBE_KERNEL);
  331. }
  332. }
// Builds one KernelBuildInfo per kernel-info variant registered for kernel_node
// and appends them to kernel_info_list. The variant count is the length of the
// dtypes list of the first input (or of the first output when the op has no
// inputs). Returns false when any variant's input/output tables are malformed.
bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr, Processor processor,
                   std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(kernel_info_list);
  size_t real_input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  size_t real_output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  std::vector<std::shared_ptr<OpIOInfo>> inputs = op_info_ptr->inputs_ptr();
  std::vector<std::shared_ptr<OpIOInfo>> outputs = op_info_ptr->outputs_ptr();
  std::vector<int> dyn_input_sizes;
  auto primitive = AnfAlgo::GetCNodePrimitive(kernel_node);
  MS_EXCEPTION_IF_NULL(primitive);
  // Dynamic inputs carry their per-input element counts in "dyn_input_sizes".
  if (primitive->GetAttr("dyn_input_sizes") != nullptr) {
    dyn_input_sizes = GetValue<std::vector<int>>(primitive->GetAttr("dyn_input_sizes"));
  }
  if (inputs.size() > 0) {
    // One build-info variant per dtype entry of the first input.
    MS_EXCEPTION_IF_NULL(inputs[0]);
    size_t kernel_info_cnt = inputs[0]->dtypes().size();
    for (size_t j = 0; j < kernel_info_cnt; j++) {
      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(builder);
      SetKernelBuildInfo(builder, processor, op_info_ptr);
      if (!SetInputKernelBuilderInfo(inputs, real_input_num, j, dyn_input_sizes, builder)) {
        MS_LOG(DEBUG) << "Parse kernel metadata, set inputs kernel builder info failed.";
        return false;
      }
      if (outputs.size() > 0) {
        if (!SetOutputKernelBuilderInfo(outputs, j, real_output_num, builder)) {
          MS_LOG(DEBUG) << "Parse kernel metadata, set outputs kernel builder info failed.";
          return false;
        }
      }
      kernel_info_list->push_back(builder->Build());
    }
  } else if (outputs.size() > 0) {
    // No inputs: the variant count comes from the first output instead.
    MS_EXCEPTION_IF_NULL(outputs[0]);
    size_t kernel_info_cnt = outputs[0]->dtypes().size();
    for (size_t j = 0; j < kernel_info_cnt; j++) {
      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(builder);
      SetKernelBuildInfo(builder, processor, op_info_ptr);
      if (!SetOutputKernelBuilderInfo(outputs, j, real_output_num, builder)) {
        MS_LOG(DEBUG) << "Parse kernel metadata, set outputs kernel builder info failed.";
        return false;
      }
      kernel_info_list->push_back(builder->Build());
    }
  } else {
    // No inputs and no outputs: only AICPU ops still get a (processor-only) build info.
    if (processor == AICPU) {
      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(builder);
      SetKernelBuildInfo(builder, processor, op_info_ptr);
      kernel_info_list->push_back(builder->Build());
    }
  }
  return true;
}
  389. void SaveJsonInfo(const std::string &json_name, const std::string &info) {
  390. char real_path[PATH_MAX] = {0};
  391. std::string path = kCceKernelMeta + json_name + kInfoSuffix;
  392. if (path.size() > PATH_MAX) {
  393. MS_LOG(DEBUG) << "file path " << path << " is too long.";
  394. return;
  395. }
  396. std::ofstream filewrite;
  397. filewrite.open(path);
  398. if (!filewrite.is_open()) {
  399. return;
  400. }
  401. filewrite << info << std::endl;
  402. filewrite.close();
  403. #if defined(_WIN32) || defined(_WIN64)
  404. if (nullptr == _fullpath(real_path, path.c_str(), PATH_MAX)) {
  405. MS_LOG(DEBUG) << "dir " << path << " does not exit.";
  406. return;
  407. }
  408. #else
  409. if (nullptr == realpath(path.c_str(), real_path)) {
  410. MS_LOG(DEBUG) << "dir " << path << " does not exit.";
  411. return;
  412. }
  413. #endif
  414. MS_LOG(INFO) << "real path is :" << real_path;
  415. if (chmod(real_path, S_IRUSR) == -1) {
  416. MS_LOG(DEBUG) << "modify file:" << real_path << " to read only fail.";
  417. }
  418. }
  419. std::string GetProcessor(const AnfNodePtr &anf_node) {
  420. MS_EXCEPTION_IF_NULL(anf_node);
  421. std::string device;
  422. switch (AnfAlgo::GetProcessor(anf_node)) {
  423. case Processor::AICORE:
  424. device = kProcessorAiCore;
  425. break;
  426. case Processor::AICPU:
  427. device = kProcessorAiCpu;
  428. break;
  429. case Processor::CUDA:
  430. device = kProcessorCuda;
  431. break;
  432. default:
  433. MS_LOG(DEBUG) << "Unknown processor type.";
  434. break;
  435. }
  436. return device;
  437. }
  438. bool IsSameShape(const std::vector<size_t> &shape_a, const std::vector<size_t> &shape_b) {
  439. if (shape_a.size() != shape_b.size()) {
  440. return false;
  441. }
  442. for (size_t i = 0; i < shape_a.size(); ++i) {
  443. if (shape_a[i] != shape_b[i]) {
  444. return false;
  445. }
  446. }
  447. return true;
  448. }
  449. int Sign(float x) {
  450. if (x > 0) {
  451. return 1;
  452. }
  453. if (x < 0) {
  454. return -1;
  455. }
  456. return 0;
  457. }
  458. namespace {
// One bucket's slice of a sparse gradient: non-owning pointers into shared
// value/index buffers plus the number of indices in this bucket.
struct BucketSparseGradient {
  float *value_;         // gradient values; value_stride floats per index
  int *indices_;         // indices assigned to this bucket
  int *global_indices_;  // position of each index within the original input
  size_t indices_size_;  // number of valid entries in indices_
};
// Shared parameters for the multi-threaded sparse-gradient reduction pipeline.
struct MultiThreadReduceSparseGradientParam {
  SparseGradient *input_grad_{nullptr};      // original, unreduced gradient
  SparseGradient *workspace_grad_{nullptr};  // scratch buffers reused between stages
  SparseGradient *output_grad_{nullptr};     // final reduced gradient
  size_t max_index_{0};                      // indices >= max_index_ (or < 0) are dropped
  size_t value_stride_{0};                   // number of floats per index
  size_t thread_num_{0};                     // worker-thread count (also the bucket count)
  bool use_sort_reduce_{false};              // choose sort-based vs hash-based reduce
};
  474. void CalculateEachBucketSize(const std::shared_ptr<SparseGradient> &sparse_grad, size_t max_index,
  475. std::vector<size_t> *each_bucket_size) {
  476. MS_LOG(DEBUG) << "Start";
  477. MS_EXCEPTION_IF_NULL(sparse_grad);
  478. MS_EXCEPTION_IF_NULL(sparse_grad->indices_);
  479. MS_EXCEPTION_IF_NULL(each_bucket_size);
  480. size_t bucket_num = each_bucket_size->size();
  481. for (size_t i = 0; i < sparse_grad->indices_size_; ++i) {
  482. int index = sparse_grad->indices_[i];
  483. if (index >= 0 && IntToSize(index) < max_index) {
  484. auto bucket_id = index % bucket_num;
  485. each_bucket_size->at(bucket_id)++;
  486. }
  487. }
  488. MS_LOG(DEBUG) << "End";
  489. }
// Splits input_grad into thread_num_ nearly equal contiguous segments and, in
// parallel, counts how many of each segment's indices fall into each bucket.
// On return, *segments_ptr holds one non-owning SparseGradient view per thread
// and *segment_bucket_sizes_ptr holds a per-segment vector of thread_num_
// bucket counts. Throws ArgumentError when thread_num_ is zero.
void SplitAndCalculateSegmentBucketSize(const MultiThreadReduceSparseGradientParam &param,
                                        std::vector<std::shared_ptr<SparseGradient>> *segments_ptr,
                                        std::vector<std::shared_ptr<std::vector<size_t>>> *segment_bucket_sizes_ptr) {
  MS_EXCEPTION_IF_NULL(param.input_grad_);
  MS_EXCEPTION_IF_NULL(segment_bucket_sizes_ptr);
  MS_EXCEPTION_IF_NULL(segments_ptr);
  auto &segments = *segments_ptr;
  auto &segment_bucket_sizes = *segment_bucket_sizes_ptr;
  auto input_grad = param.input_grad_;
  if (param.thread_num_ < 1) {
    MS_EXCEPTION(ArgumentError) << "Input param thread num must > 0!";
  }
  // Spread indices as evenly as possible: the first `left_indices_size`
  // segments each receive one extra index.
  size_t thread_indices_size = input_grad->indices_size_ / param.thread_num_;
  size_t left_indices_size = input_grad->indices_size_ % param.thread_num_;
  std::vector<std::thread> threads;
  threads.reserve(param.thread_num_);
  segments.reserve(param.thread_num_);
  size_t current_indices_offset = 0;
  for (size_t i = 0; i < param.thread_num_; ++i) {
    // Bucket count equals thread count: one bucket per reducing thread.
    segment_bucket_sizes.emplace_back(std::make_shared<std::vector<size_t>>(param.thread_num_, 0));
    size_t indices_size = thread_indices_size;
    if (i < left_indices_size) {
      indices_size += 1;
    }
    segments.emplace_back(std::make_shared<SparseGradient>());
    // Segments alias input_grad's buffers; no data is copied here.
    segments[i]->value_ = input_grad->value_ + current_indices_offset * param.value_stride_;
    segments[i]->indices_ = input_grad->indices_ + current_indices_offset;
    segments[i]->indices_size_ = indices_size;
    threads.emplace_back(
      std::thread(CalculateEachBucketSize, segments[i], param.max_index_, segment_bucket_sizes[i].get()));
    current_indices_offset += indices_size;
  }
  for (size_t i = 0; i < param.thread_num_; ++i) {
    threads[i].join();
  }
}
// Scatters one segment's valid indices into this thread's bucket views.
// bucket_offset is the segment's start position within the whole input, so
// global_indices_ records each index's original (global) row. Each worker
// receives its own non-overlapping `buckets` views, so no locking is needed.
void CopySegmentIndicesToBucket(const MultiThreadReduceSparseGradientParam &param,
                                const std::shared_ptr<SparseGradient> &segment, size_t bucket_offset,
                                const std::vector<std::shared_ptr<BucketSparseGradient>> &buckets) {
  MS_LOG(DEBUG) << "Start";
  MS_EXCEPTION_IF_NULL(segment);
  MS_EXCEPTION_IF_NULL(segment->indices_);
  // Per-bucket count of entries written so far (next free slot in each bucket view).
  std::vector<size_t> bucket_data_num(param.thread_num_, 0);
  for (size_t i = 0; i < segment->indices_size_; ++i) {
    int index = segment->indices_[i];
    if (index >= 0 && IntToSize(index) < param.max_index_) {
      auto bucket_id = index % param.thread_num_;
      auto bucket_index = bucket_data_num[bucket_id];
      buckets[bucket_id]->indices_[bucket_index] = index;
      buckets[bucket_id]->global_indices_[bucket_index] = bucket_offset + i;
      bucket_data_num[bucket_id]++;
    }
  }
  MS_LOG(DEBUG) << "End";
}
// Lays out the output buckets over output_grad's buffers and copies every
// segment's indices into them in parallel. segment_bucket_sizes (produced by
// SplitAndCalculateSegmentBucketSize) fixes each copy thread's private,
// non-overlapping write window inside every bucket, so the threads never
// contend. On return *buckets_ptr holds thread_num_ bucket views.
// NOTE(review): param.workspace_grad_->indices_ is dereferenced below but,
// unlike output_grad_, workspace_grad_ is not null-checked — confirm callers
// always set it.
void GatherSegmentIndicesToOutputBucket(const MultiThreadReduceSparseGradientParam &param,
                                        const std::vector<std::shared_ptr<SparseGradient>> &segments,
                                        const std::vector<std::shared_ptr<std::vector<size_t>>> &segment_bucket_sizes,
                                        std::vector<std::shared_ptr<BucketSparseGradient>> *buckets_ptr) {
  MS_EXCEPTION_IF_NULL(param.output_grad_);
  MS_EXCEPTION_IF_NULL(param.output_grad_->value_);
  MS_EXCEPTION_IF_NULL(param.output_grad_->indices_);
  MS_EXCEPTION_IF_NULL(buckets_ptr);
  auto &buckets = *buckets_ptr;
  size_t thread_num = param.thread_num_;
  if (thread_num != segment_bucket_sizes.size()) {
    MS_EXCEPTION(ArgumentError) << "Input param thread num not equal to segment size!";
  }
  // Total size of each bucket = sum of that bucket's count over all segments.
  std::vector<size_t> bucket_data_size(thread_num, 0);
  for (size_t i = 0; i < thread_num; ++i) {
    for (size_t j = 0; j < thread_num; ++j) {
      bucket_data_size[j] += segment_bucket_sizes[i]->at(j);
    }
  }
  // Carve the output buffers into one contiguous region per bucket.
  size_t current_indices_offset = 0;
  for (size_t i = 0; i < thread_num; ++i) {
    buckets.emplace_back(std::make_shared<BucketSparseGradient>());
    buckets[i]->value_ = param.output_grad_->value_ + current_indices_offset * param.value_stride_;
    buckets[i]->indices_ = param.output_grad_->indices_ + current_indices_offset;
    buckets[i]->global_indices_ = param.workspace_grad_->indices_ + current_indices_offset;
    buckets[i]->indices_size_ = bucket_data_size[i];
    current_indices_offset += bucket_data_size[i];
  }
  // For each segment i, build per-bucket sub-views that start where the
  // previous segments' entries end — this is what makes the parallel copy
  // below race-free.
  std::vector<size_t> tmp_bucket_data_size(thread_num, 0);
  std::vector<std::vector<std::shared_ptr<BucketSparseGradient>>> each_thread_buckets;
  for (size_t i = 0; i < thread_num; ++i) {
    std::vector<std::shared_ptr<BucketSparseGradient>> thread_buckets;
    for (size_t j = 0; j < thread_num; ++j) {
      thread_buckets.emplace_back(std::make_shared<BucketSparseGradient>());
      thread_buckets[j]->indices_ = buckets[j]->indices_ + tmp_bucket_data_size[j];
      thread_buckets[j]->global_indices_ = buckets[j]->global_indices_ + tmp_bucket_data_size[j];
      thread_buckets[j]->value_ = buckets[j]->value_ + tmp_bucket_data_size[j] * param.value_stride_;
      thread_buckets[j]->indices_size_ = segment_bucket_sizes[i]->at(j);
      tmp_bucket_data_size[j] += segment_bucket_sizes[i]->at(j);
    }
    each_thread_buckets.emplace_back(thread_buckets);
  }
  std::vector<std::thread> threads;
  threads.reserve(thread_num);
  current_indices_offset = 0;
  for (size_t i = 0; i < thread_num; ++i) {
    threads.emplace_back(
      std::thread(CopySegmentIndicesToBucket, param, segments[i], current_indices_offset, each_thread_buckets[i]));
    current_indices_offset += segments[i]->indices_size_;
  }
  for (size_t i = 0; i < thread_num; ++i) {
    threads[i].join();
  }
}
  599. void SortAndReduceBucketSparseGradient(const MultiThreadReduceSparseGradientParam &param,
  600. const std::shared_ptr<BucketSparseGradient> &bucket,
  601. const std::shared_ptr<SparseGradient> &reduced_bucket) {
  602. MS_LOG(DEBUG) << "Start";
  603. MS_EXCEPTION_IF_NULL(bucket);
  604. MS_EXCEPTION_IF_NULL(bucket->value_);
  605. MS_EXCEPTION_IF_NULL(bucket->indices_);
  606. MS_EXCEPTION_IF_NULL(reduced_bucket);
  607. MS_EXCEPTION_IF_NULL(reduced_bucket->value_);
  608. MS_EXCEPTION_IF_NULL(reduced_bucket->indices_);
  609. std::vector<std::pair<int, int>> sorted_indices;
  610. sorted_indices.reserve(bucket->indices_size_);
  611. for (size_t i = 0; i < bucket->indices_size_; ++i) {
  612. int index = bucket->indices_[i];
  613. int global_index = bucket->global_indices_[i];
  614. sorted_indices.emplace_back(std::pair<int, int>(index, global_index));
  615. }
  616. std::sort(sorted_indices.begin(), sorted_indices.end());
  617. float *global_value = param.input_grad_->value_;
  618. size_t unique_indices_size = 0;
  619. size_t max_length = reduced_bucket->indices_size_ * param.value_stride_;
  620. int last_index{0};
  621. size_t value_offset{0};
  622. for (size_t i = 0; i < sorted_indices.size(); ++i) {
  623. int index = sorted_indices[i].first;
  624. int global_index = sorted_indices[i].second;
  625. int global_value_offset = global_index * param.value_stride_;
  626. if (i == 0 || index != last_index) {
  627. if (i != 0) {
  628. unique_indices_size++;
  629. }
  630. reduced_bucket->indices_[unique_indices_size] = index;
  631. value_offset = unique_indices_size * param.value_stride_;
  632. auto ret_code = memcpy_s(reduced_bucket->value_ + value_offset, (max_length - value_offset) * sizeof(float),
  633. global_value + global_value_offset, param.value_stride_ * sizeof(float));
  634. if (ret_code != EOK) {
  635. MS_LOG(EXCEPTION) << "Failed to copy data!";
  636. }
  637. } else {
  638. for (size_t j = 0; j < param.value_stride_; ++j) {
  639. reduced_bucket->value_[value_offset + j] += global_value[global_value_offset + j];
  640. }
  641. }
  642. last_index = index;
  643. }
  644. reduced_bucket->indices_size_ = unique_indices_size;
  645. MS_LOG(DEBUG) << "End";
  646. }
  647. void ReduceBucketSparseGradient(const MultiThreadReduceSparseGradientParam &param,
  648. const std::shared_ptr<BucketSparseGradient> &bucket,
  649. const std::shared_ptr<SparseGradient> &reduced_bucket) {
  650. MS_LOG(DEBUG) << "Start";
  651. MS_EXCEPTION_IF_NULL(bucket);
  652. MS_EXCEPTION_IF_NULL(bucket->value_);
  653. MS_EXCEPTION_IF_NULL(bucket->indices_);
  654. MS_EXCEPTION_IF_NULL(reduced_bucket);
  655. MS_EXCEPTION_IF_NULL(reduced_bucket->value_);
  656. MS_EXCEPTION_IF_NULL(reduced_bucket->indices_);
  657. float *global_value = param.input_grad_->value_;
  658. std::unordered_map<int, size_t> index_map;
  659. size_t unique_indices_size = 0;
  660. size_t max_length = reduced_bucket->indices_size_ * param.value_stride_;
  661. for (size_t i = 0; i < bucket->indices_size_; ++i) {
  662. int index = bucket->indices_[i];
  663. int global_index = bucket->global_indices_[i];
  664. auto iter = index_map.find(index);
  665. if (iter == index_map.end()) {
  666. reduced_bucket->indices_[unique_indices_size] = index;
  667. size_t start_index = unique_indices_size * param.value_stride_;
  668. index_map[index] = start_index;
  669. auto ret_code = memcpy_s(reduced_bucket->value_ + start_index, (max_length - start_index) * sizeof(float),
  670. global_value + global_index * param.value_stride_, param.value_stride_ * sizeof(float));
  671. if (ret_code != EOK) {
  672. MS_LOG(EXCEPTION) << "Failed to copy data!";
  673. }
  674. unique_indices_size++;
  675. } else {
  676. size_t start_index = iter->second;
  677. size_t end_index = start_index + param.value_stride_;
  678. for (size_t j = start_index, k = global_index * param.value_stride_; j < end_index; ++j, ++k) {
  679. reduced_bucket->value_[j] += global_value[k];
  680. }
  681. }
  682. }
  683. reduced_bucket->indices_size_ = unique_indices_size;
  684. MS_LOG(DEBUG) << "End";
  685. }
// Reduces every bucket into the pre-allocated workspace gradient, one thread per
// bucket. Each reduced bucket is given a disjoint slice of the workspace whose
// start is the running sum of the preceding buckets' indices counts, so the
// worker threads never write overlapping memory.
void ReduceBucketSparseGradientToWorkspace(const MultiThreadReduceSparseGradientParam &param,
                                           const std::vector<std::shared_ptr<BucketSparseGradient>> &buckets,
                                           std::vector<std::shared_ptr<SparseGradient>> *reduced_buckets_ptr) {
  MS_EXCEPTION_IF_NULL(param.workspace_grad_);
  MS_EXCEPTION_IF_NULL(param.workspace_grad_->value_);
  MS_EXCEPTION_IF_NULL(param.workspace_grad_->indices_);
  MS_EXCEPTION_IF_NULL(reduced_buckets_ptr);
  auto &reduced_buckets = *reduced_buckets_ptr;
  size_t thread_num = buckets.size();
  std::vector<std::thread> threads;
  threads.reserve(thread_num);
  // Offset (counted in indices) of the current bucket's slice inside the workspace.
  size_t current_indices_offset = 0;
  for (size_t i = 0; i < thread_num; ++i) {
    reduced_buckets.emplace_back(std::make_shared<SparseGradient>());
    reduced_buckets[i]->value_ = param.workspace_grad_->value_ + current_indices_offset * param.value_stride_;
    reduced_buckets[i]->indices_ = param.workspace_grad_->indices_ + current_indices_offset;
    // Before reduction the slice capacity equals the bucket's (possibly duplicated)
    // indices count; the worker shrinks indices_size_ down to the unique count.
    reduced_buckets[i]->indices_size_ = buckets[i]->indices_size_;
    if (param.use_sort_reduce_) {
      threads.emplace_back(std::thread(SortAndReduceBucketSparseGradient, param, buckets[i], reduced_buckets[i]));
    } else {
      threads.emplace_back(std::thread(ReduceBucketSparseGradient, param, buckets[i], reduced_buckets[i]));
    }
    current_indices_offset += buckets[i]->indices_size_;
  }
  for (size_t i = 0; i < thread_num; ++i) {
    threads[i].join();
  }
}
// Concatenates all non-empty reduced buckets (value rows and index arrays) into
// param.output_grad_ and sets its indices_size_ to the total number of unique
// indices copied.
void MergeReduceSparseGradient(const MultiThreadReduceSparseGradientParam &param,
                               const std::vector<std::shared_ptr<SparseGradient>> &reduced_buckets) {
  MS_EXCEPTION_IF_NULL(param.output_grad_);
  auto output_grad = param.output_grad_;
  MS_EXCEPTION_IF_NULL(output_grad->value_);
  MS_EXCEPTION_IF_NULL(output_grad->indices_);
  // Bytes occupied by one value row.
  size_t stride_data_size = param.value_stride_ * sizeof(float);
  size_t unique_indices_size = 0;
  for (size_t i = 0; i < reduced_buckets.size(); ++i) {
    auto &bucket = reduced_buckets[i];
    MS_EXCEPTION_IF_NULL(bucket);
    if (bucket->indices_size_ == 0) {
      continue;
    }
    // Append this bucket's value rows after the rows already written; the
    // destination bound is the remaining capacity of the output value buffer.
    auto ret_code = memcpy_s(output_grad->value_ + unique_indices_size * param.value_stride_,
                             (output_grad->indices_size_ - unique_indices_size) * stride_data_size, bucket->value_,
                             bucket->indices_size_ * stride_data_size);
    if (ret_code != EOK) {
      MS_LOG(EXCEPTION) << "Failed to copy data!";
    }
    // Append this bucket's indices at the matching position.
    ret_code = memcpy_s(output_grad->indices_ + unique_indices_size,
                        (output_grad->indices_size_ - unique_indices_size) * sizeof(int), bucket->indices_,
                        bucket->indices_size_ * sizeof(int));
    if (ret_code != EOK) {
      MS_LOG(EXCEPTION) << "Failed to copy data!";
    }
    unique_indices_size += bucket->indices_size_;
  }
  output_grad->indices_size_ = unique_indices_size;
}
  744. } // namespace
  745. void BucketReduceSparseGradient(const ReduceSparseGradientParam &param) {
  746. MS_LOG(DEBUG) << "Start";
  747. MS_EXCEPTION_IF_NULL(param.input_grad_);
  748. size_t thread_num = 23;
  749. if (param.input_grad_->indices_size_ < thread_num) {
  750. thread_num = param.input_grad_->indices_size_;
  751. }
  752. MultiThreadReduceSparseGradientParam multi_thread_param({param.input_grad_, param.workspace_grad_, param.output_grad_,
  753. param.max_index_, param.value_stride_, thread_num,
  754. param.use_sort_reduce_});
  755. std::vector<std::shared_ptr<SparseGradient>> segments;
  756. std::vector<std::shared_ptr<std::vector<size_t>>> segment_bucket_sizes;
  757. SplitAndCalculateSegmentBucketSize(multi_thread_param, &segments, &segment_bucket_sizes);
  758. std::vector<std::shared_ptr<BucketSparseGradient>> buckets;
  759. GatherSegmentIndicesToOutputBucket(multi_thread_param, segments, segment_bucket_sizes, &buckets);
  760. std::vector<std::shared_ptr<SparseGradient>> reduced_buckets;
  761. ReduceBucketSparseGradientToWorkspace(multi_thread_param, buckets, &reduced_buckets);
  762. MergeReduceSparseGradient(multi_thread_param, reduced_buckets);
  763. MS_LOG(DEBUG) << "End";
  764. }
  765. std::pair<AnfNodePtr, size_t> GetKernelInput(const AnfNodePtr &anf_node, size_t index) {
  766. MS_EXCEPTION_IF_NULL(anf_node);
  767. if (index >= AnfAlgo::GetInputTensorNum(anf_node)) {
  768. MS_EXCEPTION(ArgumentError) << "Index is out of the size of anf_node inputs.";
  769. }
  770. auto cnode = anf_node->cast<CNodePtr>();
  771. if (cnode == nullptr) {
  772. return AnfAlgo::VisitKernel(anf_node, 0);
  773. } else {
  774. return AnfAlgo::VisitKernel(anf_node->cast<CNodePtr>()->input(index + 1), 0);
  775. }
  776. }
// For every graph input, locates the kernel node (from node_list) that consumes
// it and records a (node, (input index, offset)) entry. For ordinary inputs the
// offset is 0; for nodes with dynamic inputs (kAttrDynInputSizes) the flat input
// position is translated into (dynamic-group index, offset inside that group).
// Raises ArgumentError when an input has no users or no consuming kernel node.
std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> GetInputIndex(const std::vector<AnfNodePtr> &node_list,
                                                                            const std::vector<AnfNodePtr> &input_list) {
  std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> input_index;
  for (size_t i = 0; i < input_list.size(); ++i) {
    auto const &input = input_list[i];
    MS_EXCEPTION_IF_NULL(input);
    bool found = false;
    // using NodeUsersMap = std::unordered_map<AnfNodePtr, std::set<std::pair<AnfNodePtr, int>>>;
    auto mng = input->func_graph()->manager();
    MS_EXCEPTION_IF_NULL(mng);
    const NodeUsersMap &users = mng->node_users();
    auto input_users = users.find(input);
    if (input_users == users.end() || input_users->second.empty()) {
      MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of ["
                                  << input->func_graph()->ToString() << "] has no users.";
    }
    // Scan the users of this input and keep the first one that is a kernel node.
    for (auto const &input_user : input_users->second) {
      for (auto const &anf_node : node_list) {
        if (anf_node != input_user.first) {
          continue;
        }
        std::vector<int> dyn_input_sizes;
        auto prim = AnfAlgo::GetCNodePrimitive(anf_node);
        MS_EXCEPTION_IF_NULL(prim);
        if (prim->GetAttr(kAttrDynInputSizes) != nullptr) {
          dyn_input_sizes = GetValue<const std::vector<int>>(prim->GetAttr(kAttrDynInputSizes));
        }
        if (dyn_input_sizes.empty()) {
          // No dynamic inputs: input_user.second is 1-based (input 0 is the
          // primitive), so subtract 1 to get the real input index.
          input_index.push_back(std::make_pair(anf_node, std::make_pair(IntToSize(input_user.second - 1), 0)));
          found = true;
          break;
        } else {
          // Dynamic inputs: walk the group sizes until the accumulated count
          // passes the used position, then record (group, offset within group).
          int used_as_idx = input_user.second - 1;
          int accum_idx = 0;
          size_t dyn_i = 0;
          for (; dyn_i < dyn_input_sizes.size(); ++dyn_i) {
            accum_idx += dyn_input_sizes[dyn_i];
            if (used_as_idx < accum_idx) {
              input_index.push_back(std::make_pair(
                anf_node, std::make_pair(dyn_i, IntToSize(used_as_idx - (accum_idx - dyn_input_sizes[dyn_i])))));
              break;
            }
          }
          // dyn_i != size means the position was matched inside some group.
          if (dyn_i != dyn_input_sizes.size()) {
            found = true;
            break;
          }
        }
      }
      if (found) {
        break;
      }
    }
    if (!found) {
      MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of ["
                                  << input->func_graph()->ToString() << "] found no related kernel info.";
    }
  }
  return input_index;
}
  837. std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list,
  838. const std::vector<AnfNodePtr> &input_list,
  839. const std::vector<AnfNodePtr> &output_list) {
  840. std::vector<std::pair<AnfNodePtr, size_t>> output_index;
  841. for (size_t i = 0; i < output_list.size(); ++i) {
  842. auto const &output = output_list[i];
  843. MS_EXCEPTION_IF_NULL(output);
  844. bool found = false;
  845. auto pree_node = AnfAlgo::VisitKernel(output, 0);
  846. auto pos = std::find(std::begin(node_list), std::end(node_list), pree_node.first);
  847. if (pos != std::end(node_list)) {
  848. output_index.push_back(pree_node);
  849. continue;
  850. }
  851. auto ret = std::find(std::begin(input_list), std::end(input_list), pree_node.first);
  852. if (ret != std::end(input_list)) {
  853. output_index.push_back(std::make_pair(pree_node.first, 0));
  854. found = true;
  855. }
  856. if (!found) {
  857. MS_EXCEPTION(ArgumentError) << "Output [" << i << "][" << output->DebugString(2) << "] of ["
  858. << output->func_graph()->ToString() << "] found no related kernel info.";
  859. }
  860. }
  861. return output_index;
  862. }
  863. void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list) {
  864. MS_EXCEPTION_IF_NULL(node_list);
  865. MS_EXCEPTION_IF_NULL(func_graph);
  866. std::vector<AnfNodePtr> node_lists = TopoSort(func_graph->get_return());
  867. for (auto const &node : node_lists) {
  868. if (!AnfAlgo::IsRealKernel(node) || !node->isa<CNode>()) {
  869. continue;
  870. }
  871. auto cnode = node->cast<CNodePtr>();
  872. MS_EXCEPTION_IF_NULL(cnode);
  873. if (IsValueNode<Primitive>(cnode->input(kAnfPrimitiveIndex))) {
  874. node_list->push_back(node);
  875. }
  876. }
  877. }
  878. void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list,
  879. std::vector<AnfNodePtr> *input_list, std::vector<AnfNodePtr> *output_list) {
  880. MS_EXCEPTION_IF_NULL(node_list);
  881. MS_EXCEPTION_IF_NULL(input_list);
  882. MS_EXCEPTION_IF_NULL(output_list);
  883. MS_EXCEPTION_IF_NULL(func_graph);
  884. GetValidKernelNodes(func_graph, node_list);
  885. auto parameters = func_graph->parameters();
  886. input_list->insert(input_list->begin(), parameters.begin(), parameters.end());
  887. auto func_output = func_graph->output();
  888. MS_EXCEPTION_IF_NULL(func_output);
  889. if (func_output->isa<CNode>()) {
  890. // multi output.
  891. auto cnode = func_output->cast<CNodePtr>();
  892. MS_EXCEPTION_IF_NULL(cnode);
  893. auto input0 = cnode->input(kAnfPrimitiveIndex);
  894. MS_EXCEPTION_IF_NULL(input0);
  895. if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
  896. for (size_t input_idx = 1; input_idx < cnode->inputs().size(); ++input_idx) {
  897. auto input_node = cnode->input(input_idx);
  898. MS_EXCEPTION_IF_NULL(input_node);
  899. output_list->push_back(AnfAlgo::VisitKernel(input_node, 0).first);
  900. }
  901. } else {
  902. // single output.
  903. output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first);
  904. }
  905. } else {
  906. // single output.
  907. output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first);
  908. }
  909. }
  910. bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *const node_json) {
  911. MS_EXCEPTION_IF_NULL(anf_node);
  912. MS_EXCEPTION_IF_NULL(node_json);
  913. auto cnode = anf_node->cast<CNodePtr>();
  914. MS_EXCEPTION_IF_NULL(cnode);
  915. if (input_idx + 1 >= cnode->size()) {
  916. MS_EXCEPTION(ArgumentError) << "input_idx [" << input_idx << "] is out of index of inputs of ["
  917. << cnode->inputs().size() << "][" << cnode->DebugString() << "]";
  918. }
  919. auto input_node = cnode->input(input_idx + 1);
  920. if (!IsValueNode<tensor::Tensor>(input_node)) {
  921. return false;
  922. }
  923. auto tensor = GetValueNode<tensor::TensorPtr>(input_node);
  924. if (tensor == nullptr) {
  925. return false;
  926. }
  927. auto type_id = tensor->data_type();
  928. auto *data = tensor->data_c();
  929. MS_EXCEPTION_IF_NULL(data);
  930. if (tensor->DataDim() > 1 || tensor->DataSize() != 1) {
  931. // not const tensor.
  932. MS_LOG(WARNING) << "We take first value of tensor whose datasize != 1, [" << input_node->DebugString(2) << "]";
  933. }
  934. if (type_id == kFloat32->type_id()) {
  935. float *val = static_cast<float *>(data);
  936. MS_EXCEPTION_IF_NULL(val);
  937. (*node_json)["value"] = val[0];
  938. MS_LOG(DEBUG) << "Value of tensor[" << cnode->DebugString() << "] is [float32][" << *val << "].";
  939. return true;
  940. } else if (type_id == kFloat16->type_id()) {
  941. float16 *val = static_cast<float16 *>(data);
  942. MS_EXCEPTION_IF_NULL(val);
  943. (*node_json)["value"] = static_cast<float>(val[0]);
  944. MS_LOG(INFO) << "Value of tensor[" << cnode->DebugString() << "] is [float16][" << *val << "].";
  945. return true;
  946. } else if (type_id == kInt32->type_id()) {
  947. int *val = static_cast<int *>(data);
  948. MS_EXCEPTION_IF_NULL(val);
  949. (*node_json)["value"] = val[0];
  950. MS_LOG(INFO) << "Value of tensor[" << cnode->DebugString() << "] is [int32][" << *val << "].";
  951. return true;
  952. }
  953. MS_LOG(ERROR) << "Unknown value type of tensor[" << cnode->DebugString() << "]";
  954. return false;
  955. }
  956. void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector<std::pair<AnfNodePtr, size_t>> *node_list) {
  957. MS_EXCEPTION_IF_NULL(func_graph);
  958. MS_EXCEPTION_IF_NULL(node_list);
  959. auto output = func_graph->output();
  960. MS_EXCEPTION_IF_NULL(output);
  961. if (AnfAlgo::IsRealKernel(output)) {
  962. // single output.
  963. node_list->push_back(std::make_pair(output, 0));
  964. return;
  965. } else if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
  966. auto output_cnode = output->cast<CNodePtr>();
  967. MS_EXCEPTION_IF_NULL(output_cnode);
  968. // multi output.
  969. auto &inputs = output_cnode->inputs();
  970. for (size_t i = 1; i < inputs.size(); ++i) {
  971. auto in_with_idx = AnfAlgo::VisitKernel(inputs[i], 0);
  972. node_list->push_back(in_with_idx);
  973. }
  974. return;
  975. }
  976. MS_EXCEPTION(ArgumentError) << "Unknown output type: " << output->DebugString(2)
  977. << " of graph: " << func_graph->ToString();
  978. }
  979. bool IsWeightBoundary(const AnfNodePtr &node) {
  980. if (node->isa<ValueNode>()) {
  981. return true;
  982. }
  983. if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) {
  984. return true;
  985. }
  986. return false;
  987. }
  988. void MultiThreadCompute(const MultiThreadComputeFunc &func, MultiThreadComputeParams *params,
  989. size_t total_compute_size) {
  990. const size_t kThreadNum = 24;
  991. std::vector<std::thread> threads;
  992. threads.reserve(kThreadNum);
  993. size_t start = 0;
  994. size_t once_compute_size = (total_compute_size + kThreadNum - 1) / kThreadNum;
  995. while (start < total_compute_size) {
  996. size_t end = (start + once_compute_size) > total_compute_size ? total_compute_size : (start + once_compute_size);
  997. threads.emplace_back(std::thread(func, params, start, end));
  998. start += once_compute_size;
  999. }
  1000. for (size_t i = 0; i < threads.size(); ++i) {
  1001. threads[i].join();
  1002. }
  1003. }
  1004. std::vector<int> GetReduceAttrAxis(const CNodePtr &cnode) {
  1005. if (AnfAlgo::GetInputTensorNum(cnode) != AnfAlgo::GetOutputTensorNum(cnode) &&
  1006. AnfAlgo::GetInputTensorNum(cnode) != 1) {
  1007. MS_LOG(EXCEPTION) << "the kind of reduce node [" << cnode->DebugString()
  1008. << "] is not single input or single output ";
  1009. }
  1010. std::vector<int> axis;
  1011. auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
  1012. auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
  1013. MS_EXCEPTION_IF_NULL(primitive);
  1014. auto axis_attr = primitive->GetAttr(kAxis);
  1015. if (axis_attr == nullptr) {
  1016. MS_LOG(ERROR) << "This node does't have axie attr.";
  1017. return std::vector<int>();
  1018. }
  1019. auto type = axis_attr->type();
  1020. MS_EXCEPTION_IF_NULL(type);
  1021. std::vector<int> axis_list;
  1022. if (type->ToString() == kTypeInt32) {
  1023. axis_list.emplace_back(GetValue<int>(axis_attr));
  1024. } else {
  1025. axis_list = GetValue<std::vector<int>>(axis_attr);
  1026. }
  1027. for (const auto &elem : axis_list) {
  1028. if (elem < 0) {
  1029. axis.emplace_back(input_shape.size() + elem);
  1030. } else {
  1031. axis.emplace_back(elem);
  1032. }
  1033. }
  1034. AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis), cnode);
  1035. return axis;
  1036. }
  1037. } // namespace kernel
  1038. } // namespace mindspore