You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

common_utils.cc 44 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "backend/kernel_compiler/common_utils.h"
  17. #include <unordered_map>
  18. #include <map>
  19. #include <iostream>
  20. #include <utility>
  21. #include <fstream>
  22. #include <algorithm>
  23. #include <thread>
  24. #include "nlohmann/json.hpp"
  25. #include "backend/session/anf_runtime_algorithm.h"
  26. #include "utils/ms_utils.h"
  27. #include "ir/manager.h"
  28. #include "ir/meta_tensor.h"
  29. #include "base/core_ops.h"
  30. #include "ir/graph_utils.h"
  31. namespace mindspore {
  32. namespace kernel {
// Attribute key / type-name literals shared by the parsers in this file.
constexpr char kAxis[] = "axis";
constexpr char kTypeInt32[] = "Int32";
// Textual dtype -> TypeId. Note "float" aliases float32, while "int"/"uint"
// keep their own generic TypeIds.
const std::unordered_map<std::string, TypeId> type_id_maps = {
  {"float", TypeId::kNumberTypeFloat32}, {"float16", TypeId::kNumberTypeFloat16},
  {"float32", TypeId::kNumberTypeFloat32}, {"float64", TypeId::kNumberTypeFloat64},
  {"int", TypeId::kNumberTypeInt}, {"int8", TypeId::kNumberTypeInt8},
  {"int16", TypeId::kNumberTypeInt16}, {"int32", TypeId::kNumberTypeInt32},
  {"int64", TypeId::kNumberTypeInt64}, {"uint", TypeId::kNumberTypeUInt},
  {"uint8", TypeId::kNumberTypeUInt8}, {"uint16", TypeId::kNumberTypeUInt16},
  {"uint32", TypeId::kNumberTypeUInt32}, {"uint64", TypeId::kNumberTypeUInt64},
  {"bool", TypeId::kNumberTypeBool},
};
// Reverse mapping TypeId -> textual dtype. Not an exact inverse of
// type_id_maps: "float" maps to kNumberTypeFloat32 above, but
// kNumberTypeFloat maps back to "float" here.
const std::map<TypeId, std::string> type_id_str_map = {
  {TypeId::kNumberTypeFloat32, "float32"}, {TypeId::kNumberTypeFloat16, "float16"},
  {TypeId::kNumberTypeFloat, "float"}, {TypeId::kNumberTypeFloat64, "float64"},
  {TypeId::kNumberTypeInt, "int"}, {TypeId::kNumberTypeInt8, "int8"},
  {TypeId::kNumberTypeInt16, "int16"}, {TypeId::kNumberTypeInt32, "int32"},
  {TypeId::kNumberTypeInt64, "int64"}, {TypeId::kNumberTypeUInt, "uint"},
  {TypeId::kNumberTypeUInt8, "uint8"}, {TypeId::kNumberTypeUInt16, "uint16"},
  {TypeId::kNumberTypeUInt32, "uint32"}, {TypeId::kNumberTypeUInt64, "uint64"},
  {TypeId::kNumberTypeBool, "bool"},
};
// Textual dtype -> abbreviated form used when composing kernel identifiers
// (e.g. "float32" -> "f32"). "bool" has no shorter spelling.
const std::unordered_map<std::string, std::string> dtype_shortdtype_map_ = {
  {"float16", "f16"}, {"float32", "f32"}, {"float64", "f64"}, {"int8", "i8"}, {"int16", "i16"}, {"int32", "i32"},
  {"int64", "i64"}, {"uint8", "u8"}, {"uint16", "u16"}, {"uint32", "u32"}, {"uint64", "u64"}, {"bool", "bool"},
};
  59. const std::unordered_map<std::string, size_t> dtype_nbyte_map = {
  60. {"float16", sizeof(float) / 2}, {"float32", sizeof(float)}, {"float64", sizeof(float) * 2},
  61. {"int8", sizeof(int) / 4}, {"int16", sizeof(int) / 2}, {"int32", sizeof(int)},
  62. {"int64", sizeof(int) * 2}, {"uint8", sizeof(int) / 4}, {"uint16", sizeof(int) / 2},
  63. {"uint32", sizeof(int)}, {"uint64", sizeof(int) * 2}, {"bool", sizeof(char)},
  64. };
// Fusion-type strings from op registration -> FusionType enum. "CONVLUTION"
// (sic) deliberately matches the enum constant's spelling used across the
// project; do not "correct" the string without changing the enum too.
const std::unordered_map<std::string, FusionType> fusion_type_maps = {
  {"CONVLUTION", FusionType::CONVLUTION}, {"ELEMWISE", FusionType::ELEMWISE}, {"COMMREDUCE", FusionType::COMMREDUCE},
  {"SEGMENT", FusionType::SEGMENT}, {"OPAQUE", FusionType::OPAQUE},
};
  69. void KernelMeta::Initialize(int pid) {
  70. if (pid == -1) {
  71. kernel_meta_path_ = std::string(kGpuKernelMeta) + "_" + std::to_string(getpid()) + "/";
  72. } else {
  73. kernel_meta_path_ = std::string(kGpuKernelMeta) + "_" + std::to_string(pid) + "/";
  74. }
  75. // remove old kernel cache
  76. RemoveKernelCache();
  77. #if defined(_WIN32) || defined(_WIN64)
  78. auto ret = mkdir(kernel_meta_path_.c_str());
  79. #else
  80. auto ret = mkdir(kernel_meta_path_.c_str(), S_IRWXG | S_IRWXU);
  81. #endif
  82. if (ret != 0) {
  83. MS_LOG(INFO) << "kernel dir [" << kernel_meta_path_ << "], will be created later";
  84. }
  85. initialized_ = true;
  86. }
  87. void KernelMeta::RemoveKernelCache() {
  88. DIR *dir = opendir(kernel_meta_path_.c_str());
  89. if (dir == nullptr) {
  90. return;
  91. }
  92. struct dirent *entry;
  93. while ((entry = readdir(dir)) != nullptr) {
  94. std::string kernel_file = entry->d_name;
  95. std::string kernel_file_realpath = kernel_meta_path_ + kernel_file;
  96. (void)remove(kernel_file_realpath.c_str());
  97. }
  98. (void)closedir(dir);
  99. (void)rmdir(kernel_meta_path_.c_str());
  100. }
  101. std::string KernelMeta::Search(const std::string &kernel_name) const {
  102. if (!initialized_) {
  103. return "";
  104. }
  105. auto iter = kernel_meta_map_.find(kernel_name);
  106. if (iter == kernel_meta_map_.end()) {
  107. return "";
  108. } else {
  109. return iter->second;
  110. }
  111. }
  112. bool KernelMeta::Insert(const std::string &kernel_name, const std::string &kernel_json) {
  113. if (!initialized_) {
  114. return false;
  115. }
  116. kernel_meta_map_[kernel_name] = kernel_json;
  117. return true;
  118. }
  119. bool CheckCache(const std::string &kernel_name) {
  120. // check cache.
  121. KernelMeta *bin_map = KernelMeta::GetInstance();
  122. if (bin_map == nullptr) {
  123. MS_LOG(DEBUG) << "kernel cache is invalid.";
  124. return false;
  125. }
  126. std::string kernel_json = bin_map->Search(kernel_name);
  127. bool ret = (!kernel_json.empty());
  128. if (ret) {
  129. MS_LOG(INFO) << "Kernel name:" << kernel_name << " has registed.";
  130. } else {
  131. MS_LOG(INFO) << "Kernel name:" << kernel_name << " will been registed.";
  132. }
  133. return ret;
  134. }
  135. KernelPackPtr SearchCache(const std::string &kernel_name, const std::string &processor) {
  136. // search cache.
  137. KernelMeta *bin_map = KernelMeta::GetInstance();
  138. if (bin_map == nullptr) {
  139. MS_LOG(DEBUG) << "kernel cache is invalid.";
  140. return nullptr;
  141. }
  142. std::string kernel_json = bin_map->Search(kernel_name);
  143. if (!kernel_json.empty()) {
  144. KernelPackPtr kernel_pack = std::make_shared<KernelPack>();
  145. // just a tmp solution.
  146. if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) {
  147. MS_LOG(DEBUG) << "Read cache json and bin file failed[" << kernel_json << "].";
  148. return nullptr;
  149. } else {
  150. return kernel_pack;
  151. }
  152. } else {
  153. MS_LOG(INFO) << "cache kernel not found[" << kernel_name << "].";
  154. return nullptr;
  155. }
  156. }
  157. KernelPackPtr InsertCache(const std::string &kernel_name, const std::string &processor) {
  158. MS_LOG(INFO) << "kernel name:" << kernel_name << ", processr:" << processor;
  159. KernelMeta *bin_map = KernelMeta::GetInstance();
  160. std::string kernel_json;
  161. if (processor == kProcessorAiCore || processor == kProcessorAiCpu) {
  162. kernel_json = kCceKernelMeta;
  163. } else {
  164. kernel_json = bin_map->kernel_meta_path();
  165. }
  166. (void)kernel_json.append(kernel_name).append(kJsonSuffix);
  167. KernelPackPtr kernel_pack = std::make_shared<KernelPack>();
  168. if (!kernel_pack->ReadFromJsonFile(kernel_json, processor)) {
  169. MS_LOG(DEBUG) << "Read json and bin file failed[" << kernel_json << "].";
  170. return nullptr;
  171. }
  172. if (bin_map == nullptr) {
  173. MS_LOG(DEBUG) << "kernel cache is invalid.";
  174. return nullptr;
  175. }
  176. if (bin_map->Insert(kernel_name, kernel_json)) {
  177. MS_LOG(INFO) << "Insert to cache success[" << kernel_json << "], kernelname[" << kernel_name << "].";
  178. }
  179. return kernel_pack;
  180. }
  181. TypeId DtypeToTypeId(const std::string &dtypes) {
  182. auto iter = type_id_maps.find(dtypes);
  183. if (iter != type_id_maps.end()) {
  184. return iter->second;
  185. } else {
  186. MS_EXCEPTION(ArgumentError) << "Illegal input device dtype:" << dtypes;
  187. }
  188. }
  189. std::string TypeId2String(TypeId type_id) {
  190. auto iter = type_id_str_map.find(type_id);
  191. if (iter == type_id_str_map.end()) {
  192. return std::string(TypeIdLabel(type_id));
  193. }
  194. return iter->second;
  195. }
  196. std::string Dtype2ShortType(const std::string &dtypes) {
  197. auto iter = dtype_shortdtype_map_.find(dtypes);
  198. if (iter != dtype_shortdtype_map_.end()) {
  199. return iter->second;
  200. } else {
  201. MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtypes;
  202. }
  203. }
  204. size_t GetDtypeNbyte(const std::string &dtypes) {
  205. auto iter = dtype_nbyte_map.find(dtypes);
  206. if (iter != dtype_nbyte_map.end()) {
  207. return iter->second;
  208. } else {
  209. MS_EXCEPTION(ArgumentError) << "Illegal input dtype:" << dtypes;
  210. }
  211. }
  212. bool SetInputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inputs, size_t real_input_num,
  213. size_t builder_idex, const std::vector<int> &dyn_input_sizes,
  214. const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
  215. MS_EXCEPTION_IF_NULL(builder);
  216. std::vector<TypeId> inputs_device_type;
  217. std::vector<std::string> inputs_format;
  218. size_t dyn_input_idx = 0;
  219. size_t kernel_info_index = 0;
  220. MS_EXCEPTION_IF_NULL(inputs[0]);
  221. size_t kernel_info_cnt = inputs[0]->dtypes().size();
  222. for (const auto &input : inputs) {
  223. MS_EXCEPTION_IF_NULL(input);
  224. std::string param_type = input->param_type();
  225. std::vector<std::string> dtypes = input->dtypes();
  226. std::vector<std::string> formats = input->formats();
  227. if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) {
  228. MS_LOG(DEBUG) << "Set input kernel builder info, dtyps size != formats size.";
  229. return false;
  230. }
  231. if (param_type == "dynamic") {
  232. if (dyn_input_sizes.empty()) {
  233. MS_LOG(DEBUG) << "Set input kernel builder info, dyn_input_sizes's size is 0 when param_type is dynamic";
  234. return false;
  235. }
  236. for (int t = 0; t < dyn_input_sizes[dyn_input_idx]; t++) {
  237. kernel_info_index++;
  238. auto type_id = DtypeToTypeId(dtypes[builder_idex]);
  239. inputs_device_type.push_back(type_id);
  240. inputs_format.push_back(formats[builder_idex]);
  241. }
  242. dyn_input_idx++;
  243. } else if (param_type == "required") {
  244. kernel_info_index++;
  245. auto type_id = DtypeToTypeId(dtypes[builder_idex]);
  246. inputs_device_type.push_back(type_id);
  247. inputs_format.push_back(formats[builder_idex]);
  248. } else {
  249. if (kernel_info_index < real_input_num) {
  250. MS_LOG(INFO) << "Set input kernel builder info, input type is optional, input index is :" << kernel_info_index;
  251. kernel_info_index++;
  252. auto type_id = DtypeToTypeId(dtypes[builder_idex]);
  253. inputs_device_type.push_back(type_id);
  254. inputs_format.push_back(formats[builder_idex]);
  255. }
  256. }
  257. }
  258. builder->SetInputsDeviceType(inputs_device_type);
  259. builder->SetInputsFormat(inputs_format);
  260. return true;
  261. }
// Fills the builder's output formats/device types from column `builder_idex`
// of the registered output table. Returns false when a registered dtype list
// and format list disagree in length.
bool SetOutputKernelBuilderInfo(const std::vector<std::shared_ptr<OpIOInfo>> &outputs, size_t builder_idex,
                                const size_t &real_output_num,
                                const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
  // not now but in the next we need to support dynamic output case
  MS_EXCEPTION_IF_NULL(builder);
  size_t output_idx = 0;  // count of real outputs emitted so far
  std::vector<TypeId> outputs_device_type;
  std::vector<std::string> outputs_format;
  MS_EXCEPTION_IF_NULL(outputs[0]);
  // Every output is expected to register the same number of dtype variants.
  size_t kernel_info_cnt = outputs[0]->dtypes().size();
  for (const auto &output : outputs) {
    MS_EXCEPTION_IF_NULL(output);
    if (output_idx >= real_output_num) {
      // All real outputs already covered; skip remaining registered outputs.
      MS_LOG(DEBUG) << "real_output_num:" << real_output_num << ", output_idx:" << output_idx << " is out of limit!";
      continue;
    }
    // How many real outputs this registered output expands to.
    size_t output_num = 0;
    if (output->param_type() == "dynamic") {
      // A dynamic output absorbs all real outputs, so it must be the only one.
      if (outputs.size() > 1) {
        MS_EXCEPTION(ArgumentError) << "Dynamic output is unsupported multi output!";
      }
      output_num = real_output_num;
    } else if (output->param_type() == "required") {
      output_num = 1;
    } else {
      // Optional output: emit it only while real outputs remain.
      if (output_idx < real_output_num) {
        MS_LOG(DEBUG) << "Set output kernel builder info, output type is optional, output index is :" << output_idx;
        output_num = 1;
      }
    }
    for (size_t i = 0; i < output_num; i++) {
      std::vector<std::string> dtypes = output->dtypes();
      std::vector<std::string> formats = output->formats();
      if (dtypes.size() != kernel_info_cnt || formats.size() != kernel_info_cnt) {
        MS_LOG(DEBUG) << "Set output kernel builder info, dtyps size != formats size.";
        return false;
      }
      auto type_id = DtypeToTypeId(dtypes[builder_idex]);
      outputs_device_type.push_back(type_id);
      outputs_format.push_back(formats[builder_idex]);
      output_idx++;
    }
  }
  builder->SetOutputsFormat(outputs_format);
  builder->SetOutputsDeviceType(outputs_device_type);
  return true;
}
  309. void SetKernelBuildInfo(const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder, Processor processor,
  310. const std::shared_ptr<const OpInfo> &op_info_ptr) {
  311. MS_EXCEPTION_IF_NULL(builder);
  312. MS_EXCEPTION_IF_NULL(op_info_ptr);
  313. auto imply_type = op_info_ptr->imply_type();
  314. builder->SetProcessor(processor);
  315. std::string fusion_type = op_info_ptr->fusion_type();
  316. auto iter = fusion_type_maps.find(fusion_type);
  317. if (iter != fusion_type_maps.end()) {
  318. builder->SetFusionType(iter->second);
  319. } else {
  320. if (imply_type == kAKG) {
  321. MS_EXCEPTION(NotExistsError) << "Illegal fusion type from dsl register:" << fusion_type;
  322. }
  323. }
  324. if (imply_type == kAKG) {
  325. builder->SetKernelType(AKG_KERNEL);
  326. } else if (imply_type == kAICPU) {
  327. builder->SetKernelType(AICPU_KERNEL);
  328. } else {
  329. builder->SetKernelType(TBE_KERNEL);
  330. }
  331. }
// Expands an op's registered dtype/format table into one KernelBuildInfo per
// registered column and appends them to kernel_info_list. Returns false if
// any column fails to parse.
bool ParseMetadata(const CNodePtr &kernel_node, const std::shared_ptr<const OpInfo> &op_info_ptr, Processor processor,
                   std::vector<std::shared_ptr<KernelBuildInfo>> *const kernel_info_list) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  MS_EXCEPTION_IF_NULL(kernel_info_list);
  size_t real_input_num = AnfAlgo::GetInputTensorNum(kernel_node);
  size_t real_output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
  std::vector<std::shared_ptr<OpIOInfo>> inputs = op_info_ptr->inputs_ptr();
  std::vector<std::shared_ptr<OpIOInfo>> outputs = op_info_ptr->outputs_ptr();
  // "dyn_input_sizes" records how many real inputs each dynamic input
  // expands to; absent attr means no dynamic inputs.
  std::vector<int> dyn_input_sizes;
  auto primitive = AnfAlgo::GetCNodePrimitive(kernel_node);
  MS_EXCEPTION_IF_NULL(primitive);
  if (primitive->GetAttr("dyn_input_sizes") != nullptr) {
    dyn_input_sizes = GetValue<std::vector<int>>(primitive->GetAttr("dyn_input_sizes"));
  }
  if (inputs.size() > 0) {
    // The dtype-list length of the first input defines how many build-info
    // variants (columns) the op registered.
    MS_EXCEPTION_IF_NULL(inputs[0]);
    size_t kernel_info_cnt = inputs[0]->dtypes().size();
    for (size_t j = 0; j < kernel_info_cnt; j++) {
      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(builder);
      SetKernelBuildInfo(builder, processor, op_info_ptr);
      if (!SetInputKernelBuilderInfo(inputs, real_input_num, j, dyn_input_sizes, builder)) {
        MS_LOG(DEBUG) << "Parse kernel metadata, set inputs kernel builder info failed.";
        return false;
      }
      if (outputs.size() > 0) {
        if (!SetOutputKernelBuilderInfo(outputs, j, real_output_num, builder)) {
          MS_LOG(DEBUG) << "Parse kernel metadata, set outputs kernel builder info failed.";
          return false;
        }
      }
      kernel_info_list->push_back(builder->Build());
    }
  } else if (outputs.size() > 0) {
    // No inputs registered: iterate the output columns instead.
    MS_EXCEPTION_IF_NULL(outputs[0]);
    size_t kernel_info_cnt = outputs[0]->dtypes().size();
    for (size_t j = 0; j < kernel_info_cnt; j++) {
      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(builder);
      SetKernelBuildInfo(builder, processor, op_info_ptr);
      if (!SetOutputKernelBuilderInfo(outputs, j, real_output_num, builder)) {
        MS_LOG(DEBUG) << "Parse kernel metadata, set outputs kernel builder info failed.";
        return false;
      }
      kernel_info_list->push_back(builder->Build());
    }
  } else {
    // Neither inputs nor outputs: only AICPU registers a single default
    // build info; other processors produce nothing.
    if (processor == AICPU) {
      auto builder = std::make_shared<KernelBuildInfo::KernelBuildInfoBuilder>();
      MS_EXCEPTION_IF_NULL(builder);
      SetKernelBuildInfo(builder, processor, op_info_ptr);
      kernel_info_list->push_back(builder->Build());
    }
  }
  return true;
}
  388. void SaveJsonInfo(const std::string &json_name, const std::string &info) {
  389. char real_path[PATH_MAX] = {0};
  390. std::string path = kCceKernelMeta + json_name + kInfoSuffix;
  391. if (path.size() > PATH_MAX) {
  392. MS_LOG(DEBUG) << "file path " << path << " is too long.";
  393. return;
  394. }
  395. std::ofstream filewrite;
  396. filewrite.open(path);
  397. if (!filewrite.is_open()) {
  398. return;
  399. }
  400. filewrite << info << std::endl;
  401. filewrite.close();
  402. #if defined(_WIN32) || defined(_WIN64)
  403. if (nullptr == _fullpath(real_path, path.c_str(), PATH_MAX)) {
  404. MS_LOG(DEBUG) << "dir " << path << " does not exit.";
  405. return;
  406. }
  407. #else
  408. if (nullptr == realpath(path.c_str(), real_path)) {
  409. MS_LOG(DEBUG) << "dir " << path << " does not exit.";
  410. return;
  411. }
  412. #endif
  413. MS_LOG(INFO) << "real path is :" << real_path;
  414. if (chmod(real_path, S_IRUSR) == -1) {
  415. MS_LOG(DEBUG) << "modify file:" << real_path << " to read only fail.";
  416. }
  417. }
  418. std::string GetProcessor(const AnfNodePtr &anf_node) {
  419. MS_EXCEPTION_IF_NULL(anf_node);
  420. std::string device;
  421. switch (AnfAlgo::GetProcessor(anf_node)) {
  422. case Processor::AICORE:
  423. device = kProcessorAiCore;
  424. break;
  425. case Processor::AICPU:
  426. device = kProcessorAiCpu;
  427. break;
  428. case Processor::CUDA:
  429. device = kProcessorCuda;
  430. break;
  431. default:
  432. MS_LOG(DEBUG) << "Unknown processor type.";
  433. break;
  434. }
  435. return device;
  436. }
  437. bool IsSameShape(const std::vector<size_t> &shape_a, const std::vector<size_t> &shape_b) {
  438. if (shape_a.size() != shape_b.size()) {
  439. return false;
  440. }
  441. for (size_t i = 0; i < shape_a.size(); ++i) {
  442. if (shape_a[i] != shape_b[i]) {
  443. return false;
  444. }
  445. }
  446. return true;
  447. }
  448. int Sign(float x) {
  449. if (x > 0) {
  450. return 1;
  451. }
  452. if (x < 0) {
  453. return -1;
  454. }
  455. return 0;
  456. }
namespace {
// One bucket's non-owning view into the gathered sparse gradient buffers.
struct BucketSparseGradient {
  float *value_;           // gradient values belonging to this bucket
  int *indices_;           // row indices collected into this bucket
  int *global_indices_;    // position of each entry in the original gradient
  size_t indices_size_;    // number of entries in indices_/global_indices_
};
// Shared parameters for the multi-threaded sparse-gradient reduction.
// All pointers are non-owning.
struct MultiThreadReduceSparseGradientParam {
  SparseGradient *input_grad_{nullptr};      // unreduced input gradient
  SparseGradient *workspace_grad_{nullptr};  // scratch buffers (global indices)
  SparseGradient *output_grad_{nullptr};     // destination for reduced result
  size_t max_index_{0};      // indices >= max_index_ (or < 0) are discarded
  size_t value_stride_{0};   // number of floats per gradient row
  size_t thread_num_{0};     // worker thread count == bucket count
  bool use_sort_reduce_{false};  // pick sort-based vs hash-based bucket reduce
};
  473. void CalculateEachBucketSize(const std::shared_ptr<SparseGradient> &sparse_grad, size_t max_index,
  474. std::vector<size_t> *each_bucket_size) {
  475. MS_LOG(DEBUG) << "Start";
  476. MS_EXCEPTION_IF_NULL(sparse_grad);
  477. MS_EXCEPTION_IF_NULL(sparse_grad->indices_);
  478. MS_EXCEPTION_IF_NULL(each_bucket_size);
  479. size_t bucket_num = each_bucket_size->size();
  480. for (size_t i = 0; i < sparse_grad->indices_size_; ++i) {
  481. int index = sparse_grad->indices_[i];
  482. if (index >= 0 && IntToSize(index) < max_index) {
  483. auto bucket_id = index % bucket_num;
  484. each_bucket_size->at(bucket_id)++;
  485. }
  486. }
  487. MS_LOG(DEBUG) << "End";
  488. }
// Splits the input gradient into thread_num_ contiguous segments and, in one
// worker thread per segment, counts how many of its indices fall into each of
// the thread_num_ buckets (bucket id = index % thread_num_).
void SplitAndCalculateSegmentBucketSize(const MultiThreadReduceSparseGradientParam &param,
                                        std::vector<std::shared_ptr<SparseGradient>> *segments_ptr,
                                        std::vector<std::shared_ptr<std::vector<size_t>>> *segment_bucket_sizes_ptr) {
  MS_EXCEPTION_IF_NULL(param.input_grad_);
  MS_EXCEPTION_IF_NULL(segment_bucket_sizes_ptr);
  MS_EXCEPTION_IF_NULL(segments_ptr);
  auto &segments = *segments_ptr;
  auto &segment_bucket_sizes = *segment_bucket_sizes_ptr;
  auto input_grad = param.input_grad_;
  if (param.thread_num_ < 1) {
    MS_EXCEPTION(ArgumentError) << "Input param thread num must > 0!";
  }
  // Distribute indices as evenly as possible: the first left_indices_size
  // segments each get one extra index.
  size_t thread_indices_size = input_grad->indices_size_ / param.thread_num_;
  size_t left_indices_size = input_grad->indices_size_ % param.thread_num_;
  std::vector<std::thread> threads;
  threads.reserve(param.thread_num_);
  segments.reserve(param.thread_num_);
  size_t current_indices_offset = 0;
  for (size_t i = 0; i < param.thread_num_; ++i) {
    // One zero-initialized bucket counter vector per segment.
    segment_bucket_sizes.emplace_back(std::make_shared<std::vector<size_t>>(param.thread_num_, 0));
    size_t indices_size = thread_indices_size;
    if (i < left_indices_size) {
      indices_size += 1;
    }
    // Segments alias the input gradient's buffers; no data is copied here.
    segments.emplace_back(std::make_shared<SparseGradient>());
    segments[i]->value_ = input_grad->value_ + current_indices_offset * param.value_stride_;
    segments[i]->indices_ = input_grad->indices_ + current_indices_offset;
    segments[i]->indices_size_ = indices_size;
    threads.emplace_back(
      std::thread(CalculateEachBucketSize, segments[i], param.max_index_, segment_bucket_sizes[i].get()));
    current_indices_offset += indices_size;
  }
  for (size_t i = 0; i < param.thread_num_; ++i) {
    threads[i].join();
  }
}
// Scatters one segment's indices into its per-bucket views: each valid index
// is appended to bucket (index % thread_num_) together with its global
// position (bucket_offset + i) in the original gradient, so values can be
// gathered later. Negative or >= max_index_ indices are dropped.
void CopySegmentIndicesToBucket(const MultiThreadReduceSparseGradientParam &param,
                                const std::shared_ptr<SparseGradient> &segment, size_t bucket_offset,
                                const std::vector<std::shared_ptr<BucketSparseGradient>> &buckets) {
  MS_LOG(DEBUG) << "Start";
  MS_EXCEPTION_IF_NULL(segment);
  MS_EXCEPTION_IF_NULL(segment->indices_);
  // Per-bucket write cursor within this segment's slice of each bucket.
  std::vector<size_t> bucket_data_num(param.thread_num_, 0);
  for (size_t i = 0; i < segment->indices_size_; ++i) {
    int index = segment->indices_[i];
    if (index >= 0 && IntToSize(index) < param.max_index_) {
      auto bucket_id = index % param.thread_num_;
      auto bucket_index = bucket_data_num[bucket_id];
      buckets[bucket_id]->indices_[bucket_index] = index;
      buckets[bucket_id]->global_indices_[bucket_index] = bucket_offset + i;
      bucket_data_num[bucket_id]++;
    }
  }
  MS_LOG(DEBUG) << "End";
}
  544. void GatherSegmentIndicesToOutputBucket(const MultiThreadReduceSparseGradientParam &param,
  545. const std::vector<std::shared_ptr<SparseGradient>> &segments,
  546. const std::vector<std::shared_ptr<std::vector<size_t>>> &segment_bucket_sizes,
  547. std::vector<std::shared_ptr<BucketSparseGradient>> *buckets_ptr) {
  548. MS_EXCEPTION_IF_NULL(param.output_grad_);
  549. MS_EXCEPTION_IF_NULL(param.output_grad_->value_);
  550. MS_EXCEPTION_IF_NULL(param.output_grad_->indices_);
  551. MS_EXCEPTION_IF_NULL(buckets_ptr);
  552. auto &buckets = *buckets_ptr;
  553. size_t thread_num = param.thread_num_;
  554. if (thread_num != segment_bucket_sizes.size()) {
  555. MS_EXCEPTION(ArgumentError) << "Input param thread num not equal to segment size!";
  556. }
  557. std::vector<size_t> bucket_data_size(thread_num, 0);
  558. for (size_t i = 0; i < thread_num; ++i) {
  559. for (size_t j = 0; j < thread_num; ++j) {
  560. bucket_data_size[j] += segment_bucket_sizes[i]->at(j);
  561. }
  562. }
  563. size_t current_indices_offset = 0;
  564. for (size_t i = 0; i < thread_num; ++i) {
  565. buckets.emplace_back(std::make_shared<BucketSparseGradient>());
  566. buckets[i]->value_ = param.output_grad_->value_ + current_indices_offset * param.value_stride_;
  567. buckets[i]->indices_ = param.output_grad_->indices_ + current_indices_offset;
  568. buckets[i]->global_indices_ = param.workspace_grad_->indices_ + current_indices_offset;
  569. buckets[i]->indices_size_ = bucket_data_size[i];
  570. current_indices_offset += bucket_data_size[i];
  571. }
  572. std::vector<size_t> tmp_bucket_data_size(thread_num, 0);
  573. std::vector<std::vector<std::shared_ptr<BucketSparseGradient>>> each_thread_buckets;
  574. for (size_t i = 0; i < thread_num; ++i) {
  575. std::vector<std::shared_ptr<BucketSparseGradient>> thread_buckets;
  576. for (size_t j = 0; j < thread_num; ++j) {
  577. thread_buckets.emplace_back(std::make_shared<BucketSparseGradient>());
  578. thread_buckets[j]->indices_ = buckets[j]->indices_ + tmp_bucket_data_size[j];
  579. thread_buckets[j]->global_indices_ = buckets[j]->global_indices_ + tmp_bucket_data_size[j];
  580. thread_buckets[j]->value_ = buckets[j]->value_ + tmp_bucket_data_size[j] * param.value_stride_;
  581. thread_buckets[j]->indices_size_ = segment_bucket_sizes[i]->at(j);
  582. tmp_bucket_data_size[j] += segment_bucket_sizes[i]->at(j);
  583. }
  584. each_thread_buckets.emplace_back(thread_buckets);
  585. }
  586. std::vector<std::thread> threads;
  587. threads.reserve(thread_num);
  588. current_indices_offset = 0;
  589. for (size_t i = 0; i < thread_num; ++i) {
  590. threads.emplace_back(
  591. std::thread(CopySegmentIndicesToBucket, param, segments[i], current_indices_offset, each_thread_buckets[i]));
  592. current_indices_offset += segments[i]->indices_size_;
  593. }
  594. for (size_t i = 0; i < thread_num; ++i) {
  595. threads[i].join();
  596. }
  597. }
  598. void SortAndReduceBucketSparseGradient(const MultiThreadReduceSparseGradientParam &param,
  599. const std::shared_ptr<BucketSparseGradient> &bucket,
  600. const std::shared_ptr<SparseGradient> &reduced_bucket) {
  601. MS_LOG(DEBUG) << "Start";
  602. MS_EXCEPTION_IF_NULL(bucket);
  603. MS_EXCEPTION_IF_NULL(bucket->value_);
  604. MS_EXCEPTION_IF_NULL(bucket->indices_);
  605. MS_EXCEPTION_IF_NULL(reduced_bucket);
  606. MS_EXCEPTION_IF_NULL(reduced_bucket->value_);
  607. MS_EXCEPTION_IF_NULL(reduced_bucket->indices_);
  608. std::vector<std::pair<int, int>> sorted_indices;
  609. sorted_indices.reserve(bucket->indices_size_);
  610. for (size_t i = 0; i < bucket->indices_size_; ++i) {
  611. int index = bucket->indices_[i];
  612. int global_index = bucket->global_indices_[i];
  613. sorted_indices.emplace_back(std::pair<int, int>(index, global_index));
  614. }
  615. std::sort(sorted_indices.begin(), sorted_indices.end());
  616. float *global_value = param.input_grad_->value_;
  617. size_t unique_indices_size = 0;
  618. size_t max_length = reduced_bucket->indices_size_ * param.value_stride_;
  619. int last_index{0};
  620. size_t value_offset{0};
  621. for (size_t i = 0; i < sorted_indices.size(); ++i) {
  622. int index = sorted_indices[i].first;
  623. int global_index = sorted_indices[i].second;
  624. int global_value_offset = global_index * param.value_stride_;
  625. if (i == 0 || index != last_index) {
  626. if (i != 0) {
  627. unique_indices_size++;
  628. }
  629. reduced_bucket->indices_[unique_indices_size] = index;
  630. value_offset = unique_indices_size * param.value_stride_;
  631. auto ret_code = memcpy_s(reduced_bucket->value_ + value_offset, (max_length - value_offset) * sizeof(float),
  632. global_value + global_value_offset, param.value_stride_ * sizeof(float));
  633. if (ret_code != EOK) {
  634. MS_LOG(EXCEPTION) << "Failed to copy data!";
  635. }
  636. } else {
  637. for (size_t j = 0; j < param.value_stride_; ++j) {
  638. reduced_bucket->value_[value_offset + j] += global_value[global_value_offset + j];
  639. }
  640. }
  641. last_index = index;
  642. }
  643. reduced_bucket->indices_size_ = unique_indices_size;
  644. MS_LOG(DEBUG) << "End";
  645. }
  646. void ReduceBucketSparseGradient(const MultiThreadReduceSparseGradientParam &param,
  647. const std::shared_ptr<BucketSparseGradient> &bucket,
  648. const std::shared_ptr<SparseGradient> &reduced_bucket) {
  649. MS_LOG(DEBUG) << "Start";
  650. MS_EXCEPTION_IF_NULL(bucket);
  651. MS_EXCEPTION_IF_NULL(bucket->value_);
  652. MS_EXCEPTION_IF_NULL(bucket->indices_);
  653. MS_EXCEPTION_IF_NULL(reduced_bucket);
  654. MS_EXCEPTION_IF_NULL(reduced_bucket->value_);
  655. MS_EXCEPTION_IF_NULL(reduced_bucket->indices_);
  656. float *global_value = param.input_grad_->value_;
  657. std::unordered_map<int, size_t> index_map;
  658. size_t unique_indices_size = 0;
  659. size_t max_length = reduced_bucket->indices_size_ * param.value_stride_;
  660. for (size_t i = 0; i < bucket->indices_size_; ++i) {
  661. int index = bucket->indices_[i];
  662. int global_index = bucket->global_indices_[i];
  663. auto iter = index_map.find(index);
  664. if (iter == index_map.end()) {
  665. reduced_bucket->indices_[unique_indices_size] = index;
  666. size_t start_index = unique_indices_size * param.value_stride_;
  667. index_map[index] = start_index;
  668. auto ret_code = memcpy_s(reduced_bucket->value_ + start_index, (max_length - start_index) * sizeof(float),
  669. global_value + global_index * param.value_stride_, param.value_stride_ * sizeof(float));
  670. if (ret_code != EOK) {
  671. MS_LOG(EXCEPTION) << "Failed to copy data!";
  672. }
  673. unique_indices_size++;
  674. } else {
  675. size_t start_index = iter->second;
  676. size_t end_index = start_index + param.value_stride_;
  677. for (size_t j = start_index, k = global_index * param.value_stride_; j < end_index; ++j, ++k) {
  678. reduced_bucket->value_[j] += global_value[k];
  679. }
  680. }
  681. }
  682. reduced_bucket->indices_size_ = unique_indices_size;
  683. MS_LOG(DEBUG) << "End";
  684. }
  685. void ReduceBucketSparseGradientToWorkspace(const MultiThreadReduceSparseGradientParam &param,
  686. const std::vector<std::shared_ptr<BucketSparseGradient>> &buckets,
  687. std::vector<std::shared_ptr<SparseGradient>> *reduced_buckets_ptr) {
  688. MS_EXCEPTION_IF_NULL(param.workspace_grad_);
  689. MS_EXCEPTION_IF_NULL(param.workspace_grad_->value_);
  690. MS_EXCEPTION_IF_NULL(param.workspace_grad_->indices_);
  691. MS_EXCEPTION_IF_NULL(reduced_buckets_ptr);
  692. auto &reduced_buckets = *reduced_buckets_ptr;
  693. size_t thread_num = buckets.size();
  694. std::vector<std::thread> threads;
  695. threads.reserve(thread_num);
  696. size_t current_indices_offset = 0;
  697. for (size_t i = 0; i < thread_num; ++i) {
  698. reduced_buckets.emplace_back(std::make_shared<SparseGradient>());
  699. reduced_buckets[i]->value_ = param.workspace_grad_->value_ + current_indices_offset * param.value_stride_;
  700. reduced_buckets[i]->indices_ = param.workspace_grad_->indices_ + current_indices_offset;
  701. reduced_buckets[i]->indices_size_ = buckets[i]->indices_size_;
  702. if (param.use_sort_reduce_) {
  703. threads.emplace_back(std::thread(SortAndReduceBucketSparseGradient, param, buckets[i], reduced_buckets[i]));
  704. } else {
  705. threads.emplace_back(std::thread(ReduceBucketSparseGradient, param, buckets[i], reduced_buckets[i]));
  706. }
  707. current_indices_offset += buckets[i]->indices_size_;
  708. }
  709. for (size_t i = 0; i < thread_num; ++i) {
  710. threads[i].join();
  711. }
  712. }
  713. void MergeReduceSparseGradient(const MultiThreadReduceSparseGradientParam &param,
  714. const std::vector<std::shared_ptr<SparseGradient>> &reduced_buckets) {
  715. MS_EXCEPTION_IF_NULL(param.output_grad_);
  716. auto output_grad = param.output_grad_;
  717. MS_EXCEPTION_IF_NULL(output_grad->value_);
  718. MS_EXCEPTION_IF_NULL(output_grad->indices_);
  719. size_t stride_data_size = param.value_stride_ * sizeof(float);
  720. size_t unique_indices_size = 0;
  721. for (size_t i = 0; i < reduced_buckets.size(); ++i) {
  722. auto &bucket = reduced_buckets[i];
  723. MS_EXCEPTION_IF_NULL(bucket);
  724. if (bucket->indices_size_ == 0) {
  725. continue;
  726. }
  727. auto ret_code = memcpy_s(output_grad->value_ + unique_indices_size * param.value_stride_,
  728. (output_grad->indices_size_ - unique_indices_size) * stride_data_size, bucket->value_,
  729. bucket->indices_size_ * stride_data_size);
  730. if (ret_code != EOK) {
  731. MS_LOG(EXCEPTION) << "Failed to copy data!";
  732. }
  733. ret_code = memcpy_s(output_grad->indices_ + unique_indices_size,
  734. (output_grad->indices_size_ - unique_indices_size) * sizeof(int), bucket->indices_,
  735. bucket->indices_size_ * sizeof(int));
  736. if (ret_code != EOK) {
  737. MS_LOG(EXCEPTION) << "Failed to copy data!";
  738. }
  739. unique_indices_size += bucket->indices_size_;
  740. }
  741. output_grad->indices_size_ = unique_indices_size;
  742. }
  743. } // namespace
  744. void BucketReduceSparseGradient(const ReduceSparseGradientParam &param) {
  745. MS_LOG(DEBUG) << "Start";
  746. MS_EXCEPTION_IF_NULL(param.input_grad_);
  747. size_t thread_num = 23;
  748. if (param.input_grad_->indices_size_ < thread_num) {
  749. thread_num = param.input_grad_->indices_size_;
  750. }
  751. MultiThreadReduceSparseGradientParam multi_thread_param({param.input_grad_, param.workspace_grad_, param.output_grad_,
  752. param.max_index_, param.value_stride_, thread_num,
  753. param.use_sort_reduce_});
  754. std::vector<std::shared_ptr<SparseGradient>> segments;
  755. std::vector<std::shared_ptr<std::vector<size_t>>> segment_bucket_sizes;
  756. SplitAndCalculateSegmentBucketSize(multi_thread_param, &segments, &segment_bucket_sizes);
  757. std::vector<std::shared_ptr<BucketSparseGradient>> buckets;
  758. GatherSegmentIndicesToOutputBucket(multi_thread_param, segments, segment_bucket_sizes, &buckets);
  759. std::vector<std::shared_ptr<SparseGradient>> reduced_buckets;
  760. ReduceBucketSparseGradientToWorkspace(multi_thread_param, buckets, &reduced_buckets);
  761. MergeReduceSparseGradient(multi_thread_param, reduced_buckets);
  762. MS_LOG(DEBUG) << "End";
  763. }
  764. std::pair<AnfNodePtr, size_t> GetKernelInput(const AnfNodePtr &anf_node, size_t index) {
  765. MS_EXCEPTION_IF_NULL(anf_node);
  766. if (index >= AnfAlgo::GetInputTensorNum(anf_node)) {
  767. MS_EXCEPTION(ArgumentError) << "Index is out of the size of anf_node inputs.";
  768. }
  769. auto cnode = anf_node->cast<CNodePtr>();
  770. if (cnode == nullptr) {
  771. return AnfAlgo::VisitKernel(anf_node, 0);
  772. } else {
  773. return AnfAlgo::VisitKernel(anf_node->cast<CNodePtr>()->input(index + 1), 0);
  774. }
  775. }
// For each graph input, finds the kernel node in |node_list| that consumes it
// and returns (kernel, (input_index, offset_within_dynamic_group)) per input.
// The first matching user found wins. Throws ArgumentError when an input has
// no users or none of its users appear in |node_list|.
std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> GetInputIndex(const std::vector<AnfNodePtr> &node_list,
                                                                            const std::vector<AnfNodePtr> &input_list) {
  std::vector<std::pair<AnfNodePtr, std::pair<size_t, size_t>>> input_index;
  for (size_t i = 0; i < input_list.size(); ++i) {
    auto const &input = input_list[i];
    MS_EXCEPTION_IF_NULL(input);
    bool found = false;
    // using NodeUsersMap = std::unordered_map<AnfNodePtr, std::set<std::pair<AnfNodePtr, int>>>;
    auto mng = input->func_graph()->manager();
    MS_EXCEPTION_IF_NULL(mng);
    const NodeUsersMap &users = mng->node_users();
    auto input_users = users.find(input);
    if (input_users == users.end() || input_users->second.empty()) {
      MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of ["
                                  << input->func_graph()->ToString() << "] has no users.";
    }
    // Scan every (user node, edge index) pair; keep the first user that is
    // also a kernel node in node_list.
    for (auto const &input_user : input_users->second) {
      for (auto const &anf_node : node_list) {
        if (anf_node != input_user.first) {
          continue;
        }
        std::vector<int> dyn_input_sizes;
        auto prim = AnfAlgo::GetCNodePrimitive(anf_node);
        MS_EXCEPTION_IF_NULL(prim);
        if (prim->GetAttr(kAttrDynInputSizes) != nullptr) {
          dyn_input_sizes = GetValue<const std::vector<int>>(prim->GetAttr(kAttrDynInputSizes));
        }
        if (dyn_input_sizes.empty()) {
          // Plain inputs: the user edge index is 1-based (input 0 is the
          // primitive), so subtract 1; group offset is always 0.
          input_index.push_back(std::make_pair(anf_node, std::make_pair(IntToSize(input_user.second - 1), 0)));
          found = true;
          break;
        } else {
          // Dynamic inputs: translate the flat 0-based input position into
          // (dynamic-group index, offset inside that group) by walking the
          // cumulative group sizes.
          int used_as_idx = input_user.second - 1;
          int accum_idx = 0;
          size_t dyn_i = 0;
          for (; dyn_i < dyn_input_sizes.size(); ++dyn_i) {
            accum_idx += dyn_input_sizes[dyn_i];
            if (used_as_idx < accum_idx) {
              input_index.push_back(std::make_pair(
                anf_node, std::make_pair(dyn_i, IntToSize(used_as_idx - (accum_idx - dyn_input_sizes[dyn_i])))));
              break;
            }
          }
          // dyn_i != size means the inner loop broke, i.e. the slot was found.
          if (dyn_i != dyn_input_sizes.size()) {
            found = true;
            break;
          }
        }
      }
      if (found) {
        break;
      }
    }
    if (!found) {
      MS_EXCEPTION(ArgumentError) << "Input [" << i << "][" << input->DebugString(2) << "] of ["
                                  << input->func_graph()->ToString() << "] found no related kernel info.";
    }
  }
  return input_index;
}
  836. std::vector<std::pair<AnfNodePtr, size_t>> GetOutputIndex(const std::vector<AnfNodePtr> &node_list,
  837. const std::vector<AnfNodePtr> &input_list,
  838. const std::vector<AnfNodePtr> &output_list) {
  839. std::vector<std::pair<AnfNodePtr, size_t>> output_index;
  840. for (size_t i = 0; i < output_list.size(); ++i) {
  841. auto const &output = output_list[i];
  842. MS_EXCEPTION_IF_NULL(output);
  843. bool found = false;
  844. auto pree_node = AnfAlgo::VisitKernel(output, 0);
  845. auto pos = std::find(std::begin(node_list), std::end(node_list), pree_node.first);
  846. if (pos != std::end(node_list)) {
  847. output_index.push_back(pree_node);
  848. continue;
  849. }
  850. auto ret = std::find(std::begin(input_list), std::end(input_list), pree_node.first);
  851. if (ret != std::end(input_list)) {
  852. output_index.push_back(std::make_pair(pree_node.first, 0));
  853. found = true;
  854. }
  855. if (!found) {
  856. MS_EXCEPTION(ArgumentError) << "Output [" << i << "][" << output->DebugString(2) << "] of ["
  857. << output->func_graph()->ToString() << "] found no related kernel info.";
  858. }
  859. }
  860. return output_index;
  861. }
  862. void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list) {
  863. MS_EXCEPTION_IF_NULL(node_list);
  864. MS_EXCEPTION_IF_NULL(func_graph);
  865. std::vector<AnfNodePtr> node_lists = TopoSort(func_graph->get_return());
  866. for (auto const &node : node_lists) {
  867. if (!AnfAlgo::IsRealKernel(node) || !node->isa<CNode>()) {
  868. continue;
  869. }
  870. auto cnode = node->cast<CNodePtr>();
  871. MS_EXCEPTION_IF_NULL(cnode);
  872. if (IsValueNode<Primitive>(cnode->input(kAnfPrimitiveIndex))) {
  873. node_list->push_back(node);
  874. }
  875. }
  876. }
  877. void GetValidKernelNodes(const FuncGraphPtr &func_graph, std::vector<AnfNodePtr> *node_list,
  878. std::vector<AnfNodePtr> *input_list, std::vector<AnfNodePtr> *output_list) {
  879. MS_EXCEPTION_IF_NULL(node_list);
  880. MS_EXCEPTION_IF_NULL(input_list);
  881. MS_EXCEPTION_IF_NULL(output_list);
  882. MS_EXCEPTION_IF_NULL(func_graph);
  883. GetValidKernelNodes(func_graph, node_list);
  884. auto parameters = func_graph->parameters();
  885. input_list->insert(input_list->begin(), parameters.begin(), parameters.end());
  886. auto func_output = func_graph->output();
  887. MS_EXCEPTION_IF_NULL(func_output);
  888. if (func_output->isa<CNode>()) {
  889. // multi output.
  890. auto cnode = func_output->cast<CNodePtr>();
  891. MS_EXCEPTION_IF_NULL(cnode);
  892. auto input0 = cnode->input(kAnfPrimitiveIndex);
  893. MS_EXCEPTION_IF_NULL(input0);
  894. if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
  895. for (size_t input_idx = 1; input_idx < cnode->inputs().size(); ++input_idx) {
  896. auto input_node = cnode->input(input_idx);
  897. MS_EXCEPTION_IF_NULL(input_node);
  898. output_list->push_back(AnfAlgo::VisitKernel(input_node, 0).first);
  899. }
  900. } else {
  901. // single output.
  902. output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first);
  903. }
  904. } else {
  905. // single output.
  906. output_list->push_back(AnfAlgo::VisitKernel(func_output, 0).first);
  907. }
  908. }
// Reads the CNode input at |input_idx| and, when it is a constant tensor of
// type float32/float16/int32, stores its first element in
// (*node_json)["value"] and returns true. Returns false when the input is not
// a const tensor or has an unsupported dtype (logged as ERROR).
// Throws ArgumentError when input_idx is out of range for the CNode.
bool GetInputTensorValue(const AnfNodePtr &anf_node, size_t input_idx, nlohmann::json *const node_json) {
  MS_EXCEPTION_IF_NULL(anf_node);
  MS_EXCEPTION_IF_NULL(node_json);
  auto cnode = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // Input 0 of a CNode is the primitive, so real input input_idx is at
  // position input_idx + 1.
  if (input_idx + 1 >= cnode->size()) {
    MS_EXCEPTION(ArgumentError) << "input_idx [" << input_idx << "] is out of index of inputs of ["
                                << cnode->inputs().size() << "][" << cnode->DebugString() << "]";
  }
  auto input_node = cnode->input(input_idx + 1);
  if (!IsValueNode<tensor::Tensor>(input_node)) {
    return false;
  }
  auto tensor = GetValueNode<tensor::TensorPtr>(input_node);
  if (tensor == nullptr) {
    return false;
  }
  auto type_id = tensor->data_type();
  auto *data = tensor->data_c();
  MS_EXCEPTION_IF_NULL(data);
  // Only element [0] is serialized below, so warn (but proceed) when the
  // tensor is not a single-element scalar.
  if (tensor->DataDim() > 1 || tensor->DataSize() != 1) {
    // not const tensor.
    MS_LOG(WARNING) << "We take first value of tensor whose datasize != 1, [" << input_node->DebugString(2) << "]";
  }
  if (type_id == kFloat32->type_id()) {
    float *val = static_cast<float *>(data);
    MS_EXCEPTION_IF_NULL(val);
    (*node_json)["value"] = val[0];
    MS_LOG(DEBUG) << "Value of tensor[" << cnode->DebugString() << "] is [float32][" << *val << "].";
    return true;
  } else if (type_id == kFloat16->type_id()) {
    // float16 has no native JSON representation; widen to float.
    float16 *val = static_cast<float16 *>(data);
    MS_EXCEPTION_IF_NULL(val);
    (*node_json)["value"] = static_cast<float>(val[0]);
    MS_LOG(INFO) << "Value of tensor[" << cnode->DebugString() << "] is [float16][" << *val << "].";
    return true;
  } else if (type_id == kInt32->type_id()) {
    int *val = static_cast<int *>(data);
    MS_EXCEPTION_IF_NULL(val);
    (*node_json)["value"] = val[0];
    MS_LOG(INFO) << "Value of tensor[" << cnode->DebugString() << "] is [int32][" << *val << "].";
    return true;
  }
  MS_LOG(ERROR) << "Unknown value type of tensor[" << cnode->DebugString() << "]";
  return false;
}
  955. void GetGraphRealOutput(const FuncGraphPtr &func_graph, std::vector<std::pair<AnfNodePtr, size_t>> *node_list) {
  956. MS_EXCEPTION_IF_NULL(func_graph);
  957. MS_EXCEPTION_IF_NULL(node_list);
  958. auto output = func_graph->output();
  959. MS_EXCEPTION_IF_NULL(output);
  960. if (AnfAlgo::IsRealKernel(output)) {
  961. // single output.
  962. node_list->push_back(std::make_pair(output, 0));
  963. return;
  964. } else if (IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
  965. auto output_cnode = output->cast<CNodePtr>();
  966. MS_EXCEPTION_IF_NULL(output_cnode);
  967. // multi output.
  968. auto &inputs = output_cnode->inputs();
  969. for (size_t i = 1; i < inputs.size(); ++i) {
  970. auto in_with_idx = AnfAlgo::VisitKernel(inputs[i], 0);
  971. node_list->push_back(in_with_idx);
  972. }
  973. return;
  974. }
  975. MS_EXCEPTION(ArgumentError) << "Unknown output type: " << output->DebugString(2)
  976. << " of graph: " << func_graph->ToString();
  977. }
  978. bool IsWeightBoundary(const AnfNodePtr &node) {
  979. if (node->isa<ValueNode>()) {
  980. return true;
  981. }
  982. if (node->isa<Parameter>() && AnfAlgo::IsParameterWeight(node->cast<ParameterPtr>())) {
  983. return true;
  984. }
  985. return false;
  986. }
  987. void MultiThreadCompute(const MultiThreadComputeFunc &func, MultiThreadComputeParams *params,
  988. size_t total_compute_size) {
  989. const size_t kThreadNum = 24;
  990. std::vector<std::thread> threads;
  991. threads.reserve(kThreadNum);
  992. size_t start = 0;
  993. size_t once_compute_size = (total_compute_size + kThreadNum - 1) / kThreadNum;
  994. while (start < total_compute_size) {
  995. size_t end = (start + once_compute_size) > total_compute_size ? total_compute_size : (start + once_compute_size);
  996. threads.emplace_back(std::thread(func, params, start, end));
  997. start += once_compute_size;
  998. }
  999. for (size_t i = 0; i < threads.size(); ++i) {
  1000. threads[i].join();
  1001. }
  1002. }
  1003. std::vector<int> GetReduceAttrAxis(const CNodePtr &cnode) {
  1004. if (AnfAlgo::GetInputTensorNum(cnode) != AnfAlgo::GetOutputTensorNum(cnode) &&
  1005. AnfAlgo::GetInputTensorNum(cnode) != 1) {
  1006. MS_LOG(EXCEPTION) << "the kind of reduce node [" << cnode->DebugString()
  1007. << "] is not single input or single output ";
  1008. }
  1009. std::vector<int> axis;
  1010. auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, 0);
  1011. auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
  1012. MS_EXCEPTION_IF_NULL(primitive);
  1013. auto axis_attr = primitive->GetAttr(kAxis);
  1014. if (axis_attr == nullptr) {
  1015. MS_LOG(ERROR) << "This node does't have axie attr.";
  1016. return std::vector<int>();
  1017. }
  1018. auto type = axis_attr->type();
  1019. MS_EXCEPTION_IF_NULL(type);
  1020. std::vector<int> axis_list;
  1021. if (type->ToString() == kTypeInt32) {
  1022. axis_list.emplace_back(GetValue<int>(axis_attr));
  1023. } else {
  1024. axis_list = GetValue<std::vector<int>>(axis_attr);
  1025. }
  1026. for (const auto &elem : axis_list) {
  1027. if (elem < 0) {
  1028. axis.emplace_back(input_shape.size() + elem);
  1029. } else {
  1030. axis.emplace_back(elem);
  1031. }
  1032. }
  1033. AnfAlgo::SetNodeAttr(kAttrAxis, MakeValue(axis), cnode);
  1034. return axis;
  1035. }
  1036. } // namespace kernel
  1037. } // namespace mindspore