You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

tbe_kernel_build.cc 42 kB

5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "kernel/tbe/tbe_kernel_build.h"
  17. #include <memory>
  18. #include <map>
  19. #include <algorithm>
  20. #include "operator/ops.h"
  21. #include "parallel/ops_info/ops_utils.h"
  22. #include "session/anf_runtime_algorithm.h"
  23. #include "kernel/tbe/tbe_adapter.h"
  24. #include "kernel/tbe/tbe_python_funcs.h"
  25. #include "kernel/tbe/tbe_convert_utils.h"
  26. #include "kernel/tbe/tbe_utils.h"
namespace mindspore {
namespace kernel {
using mindspore::kernel::tbe::TbeAdapter;
using mindspore::kernel::tbe::TbeUtils;
// Top-level keys of the kernel-build json handed to the TBE python interface.
constexpr auto kFusionOpList = "op_list";
// NOTE(review): "Prfix" is misspelled, but the identifier is kept as-is to avoid
// touching every user in the file.
constexpr auto kFusionKernelNamePrfix = "te_fusion";
constexpr auto kOptional = "optional_";
constexpr auto kOpFormat_FRACTAL_Z = "FRACTAL_Z";
constexpr auto kPlatform = "platform";
constexpr auto kPlatTBE = "TBE";
constexpr auto kGenModel = "gen_model";
constexpr auto kSingle = "single";
constexpr auto kImplPath = "impl_path";
// Json field names ("kJ*" = json key).
constexpr auto kJInputs = "inputs";
constexpr auto kJOutputs = "outputs";
constexpr auto kJAttrs = "attrs";
constexpr auto kJKernelName = "kernel_name";
constexpr auto kJOpInfo = "op_info";
constexpr auto kJDtype = "dtype";
constexpr auto kJtype = "type";
constexpr auto kJName = "name";
constexpr auto kJOriShape = "ori_shape";
constexpr auto kJOriFormat = "ori_format";
constexpr auto kJShape = "shape";
constexpr auto kJFormat = "format";
constexpr auto kJValid = "valid";
constexpr auto kJParamType = "param_type";
// Values of the "param_type" field declared in op registration info.
constexpr auto kParamDynamic = "dynamic";
// NOTE(review): "Requred" is misspelled; kept because the name is used throughout.
constexpr auto kParamRequred = "required";
constexpr auto kJDataType = "data_type";
constexpr auto kJOutputIndex = "output_index";
constexpr auto kJOutputDesc = "output_desc";
constexpr auto kJInputDesc = "input_desc";
// Attribute value-type tags used by ParseAttrValue.
constexpr auto kVTypeInt = "int";
constexpr auto kVTypeStr = "str";
constexpr auto kVTypeBool = "bool";
constexpr auto kVTypeFloat = "float";
constexpr auto kVTypeListInt = "listInt";
constexpr auto kVTypeInt32 = "Int32";
constexpr auto kVTypeListUInt64 = "listUInt64";
constexpr auto kVTypeListFloat = "listFloat";
constexpr auto kVTypeListListInt = "listListInt";
constexpr auto kJValue = "value";
constexpr auto kJDynIndex = "dyn_index";
constexpr auto kJFuncName = "func_name";
  72. std::string NormalizeFullScopeName(const string &full_scope_name) {
  73. // exp:Default/ReLU-op0 -->Default_ReLU_op0
  74. string normal_ret = full_scope_name;
  75. std::replace(normal_ret.begin(), normal_ret.end(), '/', '_');
  76. std::replace(normal_ret.begin(), normal_ret.end(), '-', '_');
  77. return normal_ret;
  78. }
// Builds the complete single-op kernel json for `anf_node`:
// {platform, gen_model, impl_path, op_info:{name, inputs, outputs, attrs, kernel_name}}.
// Side effects: caches json_name_ (op name + hash of the op_info payload) and
// json_info_ (the payload itself); for SINGLE_BUILD the json is also saved to disk.
bool TbeKernelJsonCreator::GenTbeSingleKernelJson(const std::shared_ptr<mindspore::AnfNode> &anf_node,
                                                 nlohmann::json *kernel_json) {
  MS_EXCEPTION_IF_NULL(anf_node);
  MS_EXCEPTION_IF_NULL(kernel_json);
  std::string op_name = AnfAlgo::GetCNodeName(anf_node);
  auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kTBE);
  MS_EXCEPTION_IF_NULL(op_info_ptr);
  (*kernel_json)[kPlatform] = kPlatTBE;
  (*kernel_json)[kGenModel] = kSingle;
  (*kernel_json)[kImplPath] = op_info_ptr->impl_path();
  nlohmann::json op_info_json;
  if (op_info_ptr->impl_path().empty()) {
    // Built-in implementation: normalize the CNode name to the TBE func name.
    tbe::TbeAdapter::NormalizeFuncName(&op_name);
  } else {
    // Custom implementation: use the kernel name registered with the op info.
    op_name = op_info_ptr->kernel_name();
  }
  op_info_json[kJName] = op_name;
  // generate inputs json
  nlohmann::json inputs_json;
  if (!GenTbeInputsJson(anf_node, op_info_ptr, &inputs_json)) {
    MS_LOG(ERROR) << "Anf Node [" << op_name << "] generate inputs json failed";
    return false;
  }
  op_info_json[kJInputs] = inputs_json;
  // generate outputs json
  nlohmann::json outputs_json;
  if (!GenTbeOutputsJson(anf_node, op_info_ptr, &outputs_json)) {
    MS_LOG(ERROR) << "Anf Node [" << op_name << "] generate outputs json failed";
    return false;
  }
  op_info_json[kJOutputs] = outputs_json;
  // generate attrs json
  nlohmann::json attrs_json;
  (void)GenTbeAttrJson(anf_node, op_info_ptr, &attrs_json);
  op_info_json[kJAttrs] = attrs_json;
  // Hash the payload BEFORE kernel_name is inserted, so the cache key depends
  // only on the op's semantic description, not on the name chosen below.
  std::string json_str = op_info_json.dump();
  size_t hash_id = std::hash<std::string>()(json_str);
  json_name_ = op_name + "_" + std::to_string(hash_id);
  json_info_ = json_str;
  if (creater_type_ == PREBUILD) {
    // Prebuild keys kernels by normalized full scope name instead of the hash.
    op_info_json[kJKernelName] = NormalizeFullScopeName(anf_node->fullname_with_scope());
  } else {
    op_info_json[kJKernelName] = json_name_;
  }
  (*kernel_json)[kJOpInfo] = op_info_json;
  if (creater_type_ == SINGLE_BUILD) {
    // Persist the json so the build result can be reused from the kernel cache.
    TbeUtils::SaveJsonInfo(json_name_, json_info_);
  }
  MS_LOG(INFO) << "Operate type:" << creater_type_ << ", full scope name is :" << anf_node->fullname_with_scope()
               << ", json info name is : " << json_name_ << ", kernel json:" << kernel_json->dump();
  return true;
}
// Appends to `input_list` the device-side description (dtype/shape/format plus
// inferred "ori" shape/format) of real input `real_input_index` of `anf_node`.
// `value` becomes the json "valid" flag; the entry name is op_input_name+input_i.
// TopK's "input_indices" is special-cased: the adapter materializes it as a
// const tensor instead of a regular input description. Always returns true.
bool TbeKernelJsonCreator::GenInputDescJson(const std::shared_ptr<AnfNode> &anf_node, size_t real_input_index,
                                            bool value, const std::shared_ptr<OpIOInfo> &input_ptr,
                                            const string &op_input_name, size_t input_i,
                                            std::vector<nlohmann::json> *input_list) {
  MS_EXCEPTION_IF_NULL(anf_node);
  MS_EXCEPTION_IF_NULL(input_ptr);
  MS_EXCEPTION_IF_NULL(input_list);
  std::string op_name = AnfAlgo::GetCNodeName(anf_node);
  if (input_ptr->name() == "input_indices" && op_name == kTopKOpName) {
    TbeAdapter::GenTopKV2IndicesTensorInfo(anf_node, real_input_index, input_list, creater_type_);
  } else {
    auto dtype = GetDeviceInputType(anf_node, real_input_index);
    auto format = GetDeviceInputFormat(anf_node, real_input_index);
    auto shape = GetDeviceInputShape(anf_node, real_input_index);
    auto ori_shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_input_index);
    if (ori_shape.empty()) {
      // Scalars are described as shape [1] in the TBE json.
      ori_shape.emplace_back(1);
    }
    nlohmann::json input_desc_json;
    input_desc_json[kJDtype] = dtype;
    input_desc_json[kJName] = op_input_name + std::to_string(input_i);
    input_desc_json[kJOriShape] = ori_shape;
    input_desc_json[kJOriFormat] = kOpFormat_NCHW;
    input_desc_json[kJShape] = shape;
    input_desc_json[kJFormat] = format;
    input_desc_json[kJValid] = value;
    input_desc_json[kJParamType] = input_ptr->param_type();
    input_list->emplace_back(input_desc_json);
  }
  return true;
}
// Generates `input_tensor_num` input descriptions for one registered input of
// `anf_node`, advancing *real_input_index across calls. Missing optional inputs
// become placeholder entries with valid=false; a missing required input is an
// error. Returns false when the real input count does not match registration.
bool TbeKernelJsonCreator::GenInputList(const std::shared_ptr<AnfNode> &anf_node, size_t input_tensor_num,
                                        const std::shared_ptr<OpIOInfo> &input_ptr, size_t *real_input_index,
                                        string *op_input_name, std::vector<nlohmann::json> *input_list) {
  MS_EXCEPTION_IF_NULL(anf_node);
  MS_EXCEPTION_IF_NULL(input_ptr);
  MS_EXCEPTION_IF_NULL(real_input_index);
  MS_EXCEPTION_IF_NULL(op_input_name);
  MS_EXCEPTION_IF_NULL(input_list);
  std::string op_name = AnfAlgo::GetCNodeName(anf_node);
  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
  size_t real_input_num = AnfAlgo::GetInputTensorNum(anf_node);
  bool value = true;
  for (size_t input_i = 0; input_i < input_tensor_num; input_i++) {
    if (*real_input_index >= real_input_num) {
      // The node has fewer real inputs than the registration declares.
      if (input_ptr->param_type() == "optional") {
        // Emit an invalid placeholder so positions still line up for TBE.
        *op_input_name = input_ptr->name() + "_optional_";
        nlohmann::json input_desc_json;
        input_desc_json[kJValid] = false;
        input_desc_json[kJName] = *op_input_name + std::to_string(*real_input_index);
        input_list->emplace_back(input_desc_json);
        continue;
      }
      MS_LOG(ERROR) << "Input num: " << *real_input_index << " is not match op inputs";
      return false;
    }
    if (op_name == "BatchNorm") {
      // In training mode BatchNorm computes mean/variance itself, so those
      // inputs are consumed (index advanced) but not described in the json.
      if (input_ptr->name() == "mean" || input_ptr->name() == "variance") {
        auto attr = primitive->GetAttr("is_training");
        MS_EXCEPTION_IF_NULL(attr);
        bool is_training = GetValue<bool>(attr);
        MS_LOG(INFO) << "Op_name" << op_name << ", tensor_name " << input_ptr->name() << ", is_training "
                     << is_training;
        if (is_training) {
          (*real_input_index)++;
          // Stop generating further descriptions for this registered input.
          break;
        }
      }
    }
    bool ret = GenInputDescJson(anf_node, *real_input_index, value, input_ptr, *op_input_name, input_i, input_list);
    (*real_input_index)++;
    if (!ret) {
      return false;
    }
  }
  return true;
}
// Resolves, for one registered input of `anf_node`, the number of real tensors
// it maps to (*input_num) and the json name prefix (*op_input_name). Dynamic
// inputs read their count from the primitive's kAttrDynInputSizes attribute and
// advance *dyn_input_index; all other inputs map to exactly one tensor.
// Returns false when a dynamic input has no corresponding size entry.
bool GetInputNameAndRealNum(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<OpIOInfo> &input_ptr,
                            size_t *dyn_input_index, size_t *input_num, std::string *op_input_name) {
  MS_EXCEPTION_IF_NULL(anf_node);
  MS_EXCEPTION_IF_NULL(input_ptr);
  MS_EXCEPTION_IF_NULL(dyn_input_index);
  MS_EXCEPTION_IF_NULL(input_num);
  MS_EXCEPTION_IF_NULL(op_input_name);
  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
  // for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input.
  std::vector<int> dyn_input_sizes;
  if (primitive->GetAttr(kAttrDynInputSizes) != nullptr) {
    dyn_input_sizes = GetValue<const std::vector<int>>(primitive->GetAttr(kAttrDynInputSizes));
  }
  if (input_ptr->param_type() == kParamDynamic) {
    if (*dyn_input_index >= dyn_input_sizes.size()) {
      MS_LOG(ERROR) << "Dyn input index" << *dyn_input_index << "is over dyn input num" << dyn_input_sizes.size();
      return false;
    }
    *input_num = IntToSize(dyn_input_sizes[*dyn_input_index]);
    *op_input_name = input_ptr->name() + "_dynamic_";
    (*dyn_input_index)++;
    // if optional input is exist
  } else {
    // Required/optional inputs always occupy a single tensor slot.
    *input_num = 1;
    *op_input_name = input_ptr->name() + "_";
  }
  return true;
}
// Generates the "inputs" section of the kernel json: one list of descriptions
// per registered input of the op. AtomicAddrClean and ops registered without
// inputs produce an empty section. The finished lists go through the adapter's
// InputOrderPass, which may reorder inputs for specific ops.
bool TbeKernelJsonCreator::GenTbeInputsJson(const std::shared_ptr<AnfNode> &anf_node,
                                            const std::shared_ptr<OpInfo> &op_info, nlohmann::json *inputs_json) {
  MS_EXCEPTION_IF_NULL(anf_node);
  MS_EXCEPTION_IF_NULL(op_info);
  MS_EXCEPTION_IF_NULL(inputs_json);
  std::string op_name = AnfAlgo::GetCNodeName(anf_node);
  if (op_name == kAtomicAddrCleanOpName) {
    return true;
  }
  std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr = op_info->inputs_ptr();
  if (inputs_ptr.empty()) {
    MS_LOG(INFO) << "Apply kernel " << op_name << "registration info has no input info";
    return true;
  }
  auto op_info_input_num = inputs_ptr.size();
  size_t dyn_input_index = 0;
  size_t real_input_index = 0;
  std::vector<std::vector<nlohmann::json>> inputs_list;
  for (size_t i = 0; i < op_info_input_num; i++) {
    size_t input_tensor_num;
    std::shared_ptr<OpIOInfo> input_ptr = inputs_ptr[i];
    std::string op_input_name;
    MS_EXCEPTION_IF_NULL(input_ptr);
    // How many real tensors this registered input covers (handles dynamic inputs).
    if (!GetInputNameAndRealNum(anf_node, input_ptr, &dyn_input_index, &input_tensor_num, &op_input_name)) {
      return false;
    }
    std::vector<nlohmann::json> input_list;
    if (!GenInputList(anf_node, input_tensor_num, input_ptr, &real_input_index, &op_input_name, &input_list)) {
      return false;
    }
    inputs_list.emplace_back(input_list);
  }
  TbeAdapter::InputOrderPass(op_name, inputs_list, inputs_json);
  return true;
}
  271. bool TbeKernelJsonCreator::GenTbeOutputsJson(const std::shared_ptr<AnfNode> &anf_node,
  272. const std::shared_ptr<OpInfo> &op_info, nlohmann::json *outputs_json) {
  273. MS_EXCEPTION_IF_NULL(anf_node);
  274. MS_EXCEPTION_IF_NULL(op_info);
  275. MS_EXCEPTION_IF_NULL(outputs_json);
  276. auto op_name = AnfAlgo::GetCNodeName(anf_node);
  277. if (op_name == kAtomicAddrCleanOpName) {
  278. return true;
  279. }
  280. auto outputs_ptr = op_info->outputs_ptr();
  281. return GenOutputDescJson(anf_node, outputs_ptr, outputs_json);
  282. }
// Builds one description list per registered output of `anf_node`.
// - required: exactly one real output tensor;
// - dynamic: absorbs all real outputs (only valid when it is the sole
//   registered output, otherwise an error);
// - optional: if no real output remains, emit an invalid placeholder entry.
// Returns false only for an unsupported dynamic/multi-output combination.
bool TbeKernelJsonCreator::GenOutputDescJson(
  const std::shared_ptr<mindspore::AnfNode> &anf_node,
  const std::vector<std::shared_ptr<mindspore::kernel::OpIOInfo>> &outputs_ptr, nlohmann::json *outputs_json) {
  MS_EXCEPTION_IF_NULL(outputs_json);
  size_t output_idx = 0;
  auto op_name = AnfAlgo::GetCNodeName(anf_node);
  size_t real_output_num = AnfAlgo::GetOutputTensorNum(anf_node);
  for (const auto &output_ptr : outputs_ptr) {
    size_t output_obj_num = 0;
    if (output_ptr->param_type() == kParamRequred) {
      output_obj_num = 1;
    } else if (output_ptr->param_type() == kParamDynamic) {
      if (outputs_ptr.size() > 1) {
        MS_LOG(ERROR) << "Dynamic output is unsupported multi output!";
        return false;
      }
      output_obj_num = real_output_num;
    } else {
      // Optional output.
      if (output_idx >= real_output_num) {
        MS_LOG(INFO) << "Op:" << op_name << ", output" << output_ptr->name() << " is optional, output is none.";
        // Placeholder entry keeps output positions aligned for TBE.
        std::vector<nlohmann::json> output_list;
        nlohmann::json output_obj;
        output_obj[kJName] = output_ptr->name();
        output_obj[kJValid] = false;
        output_list.emplace_back(output_obj);
        (*outputs_json).push_back(output_list);
        continue;
      } else {
        output_obj_num = 1;
      }
    }
    std::vector<nlohmann::json> output_list;
    GenOutputList(anf_node, output_obj_num, output_ptr, &output_idx, &output_list);
    (*outputs_json).push_back(output_list);
  }
  return true;
}
// Appends `output_obj_num` device-side output descriptions for `anf_node` to
// `output_list`, advancing *output_idx for each real output consumed.
void TbeKernelJsonCreator::GenOutputList(const std::shared_ptr<AnfNode> &anf_node, const size_t &output_obj_num,
                                         const std::shared_ptr<OpIOInfo> &output_ptr, size_t *output_idx,
                                         std::vector<nlohmann::json> *output_list) {
  MS_EXCEPTION_IF_NULL(output_idx);
  MS_EXCEPTION_IF_NULL(output_list);
  for (size_t i = 0; i < output_obj_num; i++) {
    auto dtype = GetDeviceOutputType(anf_node, *output_idx);
    auto format = GetDeviceOutputFormat(anf_node, *output_idx);
    auto shape = GetDeviceOutputShape(anf_node, *output_idx);
    std::vector<size_t> ori_shape = AnfAlgo::GetOutputInferShape(anf_node, *output_idx);
    if (ori_shape.empty()) {
      // Scalars are described as shape [1] in the TBE json.
      ori_shape.emplace_back(1);
    }
    nlohmann::json output_obj;
    output_obj[kJDtype] = dtype;
    output_obj[kJShape] = shape;
    output_obj[kJFormat] = format;
    output_obj[kJOriShape] = ori_shape;
    output_obj[kJOriFormat] = kOpFormat_NCHW;
    output_obj[kJName] = output_ptr->name();
    output_obj[kJValid] = true;
    output_obj[kJParamType] = output_ptr->param_type();
    output_list->emplace_back(output_obj);
    (*output_idx)++;
  }
}
// Generates the "attrs" section from the op's registered attributes. The
// adapter's RunAttrPass may handle the whole op specially; otherwise each
// registered attr is looked up on the primitive. Unset attrs are emitted with
// valid=false, except that a required attr of a custom-impl op during a single
// build raises an exception. Always returns true when not delegated.
bool TbeKernelJsonCreator::GenTbeAttrJson(const std::shared_ptr<AnfNode> &anf_node,
                                          const std::shared_ptr<OpInfo> &op_info, nlohmann::json *attrs_json) {
  MS_EXCEPTION_IF_NULL(anf_node);
  MS_EXCEPTION_IF_NULL(op_info);
  MS_EXCEPTION_IF_NULL(attrs_json);
  auto attrs_ptr = op_info->attrs_ptr();
  std::string op_name = AnfAlgo::GetCNodeName(anf_node);
  if (TbeAdapter::RunAttrPass(anf_node, attrs_ptr, attrs_json)) {
    // The adapter fully produced the attrs for this op.
    return true;
  }
  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
  MS_EXCEPTION_IF_NULL(primitive);
  for (const auto &attr_ptr : attrs_ptr) {
    std::string attr_name = attr_ptr->name();
    nlohmann::json attr_obj;
    attr_obj[kJName] = attr_name;
    // LayerNorm's epsilon is skipped entirely when only selecting formats.
    if (op_name == parallel::LAYER_NORM && attr_obj[kJName] == "epsilon" && creater_type_ == OP_SELECT_FORMAT) {
      continue;
    }
    if (primitive->GetAttr(attr_name) != nullptr) {
      auto value = primitive->GetAttr(attr_name);
      std::string type = attr_ptr->type();
      ParseAttrValue(type, value, &attr_obj);
      attr_obj[kJValid] = true;
    } else {
      if (op_info->impl_path().empty()) {
        attr_obj[kJValid] = false;
      } else {
        if (attr_ptr->param_type() == kParamRequred && creater_type_ == SINGLE_BUILD) {
          MS_LOG(EXCEPTION) << "Op name: " << op_info->op_name() << " attr: " << attr_name
                            << " is required, but not set.";
        } else {
          attr_obj[kJValid] = false;
        }
      }
    }
    (*attrs_json).push_back(attr_obj);
  }
  return true;
}
  386. void TbeKernelJsonCreator::ParseAttrValue(const std::string &type, const mindspore::ValuePtr &value,
  387. nlohmann::json *attr_obj) {
  388. MS_EXCEPTION_IF_NULL(value);
  389. MS_EXCEPTION_IF_NULL(attr_obj);
  390. if (type == kVTypeInt) {
  391. auto attr_value = GetValue<int>(value);
  392. (*attr_obj)[kJValue] = attr_value;
  393. } else if (type == kVTypeStr) {
  394. auto attr_value = GetValue<std::string>(value);
  395. if (attr_value == kOpFormat_FRAC_Z) {
  396. attr_value = kOpFormat_FRACTAL_Z;
  397. }
  398. (*attr_obj)[kJValue] = attr_value;
  399. } else if (type == kVTypeBool) {
  400. auto attr_value = GetValue<bool>(value);
  401. (*attr_obj)[kJValue] = attr_value;
  402. } else if (type == kVTypeFloat) {
  403. auto attr_value = GetValue<float>(value);
  404. (*attr_obj)[kJValue] = attr_value;
  405. } else if (type == kVTypeListInt) {
  406. std::vector<int> attr_value;
  407. auto value_type = value->type();
  408. MS_EXCEPTION_IF_NULL(value_type);
  409. auto value_type_str = value_type->ToString();
  410. if (value_type_str == kVTypeInt32) {
  411. int data = GetValue<int>(value);
  412. attr_value.push_back(data);
  413. } else {
  414. attr_value = GetValue<std::vector<int>>(value);
  415. }
  416. (*attr_obj)[kJValue] = attr_value;
  417. } else if (type == kVTypeListFloat) {
  418. std::vector<float> attr_value;
  419. auto value_type = value->type();
  420. MS_EXCEPTION_IF_NULL(value_type);
  421. auto value_type_str = value_type->ToString();
  422. if (value_type_str == kVTypeFloat) {
  423. auto data = GetValue<float>(value);
  424. attr_value.push_back(data);
  425. } else {
  426. attr_value = GetValue<std::vector<float>>(value);
  427. }
  428. (*attr_obj)[kJValue] = attr_value;
  429. } else if (type == kVTypeListUInt64) {
  430. auto attr_value = GetValue<std::vector<size_t>>(value);
  431. (*attr_obj)[kJValue] = attr_value;
  432. } else if (type == kVTypeListListInt) {
  433. auto attr_value = GetValue<std::vector<std::vector<int>>>(value);
  434. (*attr_obj)[kJValue] = attr_value;
  435. } else {
  436. MS_LOG(EXCEPTION) << "Type: " << type << "not support";
  437. }
  438. }
  439. std::vector<size_t> TbeKernelJsonCreator::GetDeviceInputShape(const AnfNodePtr &anf_node, size_t real_index) const {
  440. MS_EXCEPTION_IF_NULL(anf_node);
  441. std::vector<size_t> shape;
  442. if (creater_type_ == OP_SELECT_FORMAT || creater_type_ == CHECK_SUPPORTED) {
  443. shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_index);
  444. } else {
  445. shape = AnfAlgo::GetInputDeviceShape(anf_node, real_index);
  446. }
  447. if (shape.empty()) {
  448. shape.emplace_back(1);
  449. }
  450. return shape;
  451. }
  452. std::string TbeKernelJsonCreator::GetDeviceInputType(const AnfNodePtr &anf_node, size_t real_index) const {
  453. MS_EXCEPTION_IF_NULL(anf_node);
  454. TypeId type_id;
  455. if (creater_type_ == OP_SELECT_FORMAT) {
  456. type_id = AnfAlgo::GetPrevNodeOutputInferDataType(anf_node, real_index);
  457. } else {
  458. type_id = AnfAlgo::GetInputDeviceDataType(anf_node, real_index);
  459. }
  460. return tbe::TypeIdToString(type_id);
  461. }
  462. std::string TbeKernelJsonCreator::GetDeviceInputFormat(const AnfNodePtr &anf_node, size_t real_index) const {
  463. MS_EXCEPTION_IF_NULL(anf_node);
  464. std::string format = kOpFormat_NCHW;
  465. if (creater_type_ != OP_SELECT_FORMAT && creater_type_ != CHECK_SUPPORTED) {
  466. format = AnfAlgo::GetInputFormat(anf_node, real_index);
  467. if (format == kOpFormat_FRAC_Z) {
  468. format = kOpFormat_FRACTAL_Z;
  469. } else if (format == kOpFormat_DEFAULT) {
  470. format = kOpFormat_NCHW;
  471. }
  472. }
  473. return format;
  474. }
  475. std::vector<size_t> TbeKernelJsonCreator::GetDeviceOutputShape(const AnfNodePtr &anf_node, size_t real_index) const {
  476. MS_EXCEPTION_IF_NULL(anf_node);
  477. std::vector<size_t> shape;
  478. if (creater_type_ == OP_SELECT_FORMAT || creater_type_ == CHECK_SUPPORTED) {
  479. shape = AnfAlgo::GetOutputInferShape(anf_node, real_index);
  480. } else {
  481. shape = AnfAlgo::GetOutputDeviceShape(anf_node, real_index);
  482. }
  483. if (shape.empty()) {
  484. shape.emplace_back(1);
  485. }
  486. return shape;
  487. }
  488. std::string TbeKernelJsonCreator::GetDeviceOutputType(const AnfNodePtr &anf_node, size_t real_index) const {
  489. MS_EXCEPTION_IF_NULL(anf_node);
  490. TypeId type_id;
  491. if (creater_type_ == OP_SELECT_FORMAT) {
  492. type_id = AnfAlgo::GetOutputInferDataType(anf_node, real_index);
  493. } else {
  494. type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, real_index);
  495. }
  496. return tbe::TypeIdToString(type_id);
  497. }
  498. std::string TbeKernelJsonCreator::GetDeviceOutputFormat(const AnfNodePtr &anf_node, size_t real_index) const {
  499. MS_EXCEPTION_IF_NULL(anf_node);
  500. std::string format = kOpFormat_NCHW;
  501. if (creater_type_ != OP_SELECT_FORMAT && creater_type_ != CHECK_SUPPORTED) {
  502. format = AnfAlgo::GetOutputFormat(anf_node, real_index);
  503. if (format == kOpFormat_FRAC_Z) {
  504. format = kOpFormat_FRACTAL_Z;
  505. } else if (format == kOpFormat_DEFAULT) {
  506. format = kOpFormat_NCHW;
  507. }
  508. }
  509. return format;
  510. }
  511. bool TbeKernelBuild::GetIOSize(const nlohmann::json &kernel_json, std::vector<size_t> *input_size_list,
  512. std::vector<size_t> *output_size_list) {
  513. if (input_size_list == nullptr || output_size_list == nullptr) {
  514. MS_LOG(ERROR) << "Input size or output size is nullptr";
  515. return false;
  516. }
  517. input_size_list->clear();
  518. output_size_list->clear();
  519. for (size_t i = 0; i < kernel_json[kJOpInfo][kJInputs].size(); i++) {
  520. for (size_t m = 0; m < kernel_json[kJOpInfo][kJInputs][i].size(); m++) {
  521. size_t size_i = 1;
  522. if (kernel_json[kJOpInfo][kJInputs][i][m][kJValid] == false) {
  523. std::string input_name = kernel_json[kJOpInfo][kJInputs][i][m][kJName];
  524. MS_LOG(INFO) << "Input name:" << input_name << "is optional, valid is false.";
  525. continue;
  526. }
  527. for (const auto &j : kernel_json[kJOpInfo][kJInputs][i][m][kJShape]) {
  528. size_i *= static_cast<size_t>(j);
  529. }
  530. std::string dtype = kernel_json[kJOpInfo][kJInputs][i][m][kJDtype];
  531. size_t nbyte = tbe::GetDtypeNbyte(dtype);
  532. size_i *= nbyte;
  533. input_size_list->push_back(size_i);
  534. }
  535. }
  536. for (size_t i = 0; i < kernel_json[kJOpInfo][kJOutputs].size(); i++) {
  537. for (size_t m = 0; m < kernel_json[kJOpInfo][kJOutputs][i].size(); m++) {
  538. size_t size_i = 1;
  539. if (kernel_json[kJOpInfo][kJOutputs][i][m][kJValid] == false) {
  540. std::string output_name = kernel_json[kJOpInfo][kJOutputs][i][m][kJName];
  541. MS_LOG(INFO) << "Output name:" << output_name << " is optional, valid is false.";
  542. continue;
  543. }
  544. for (const auto &j : kernel_json[kJOpInfo][kJOutputs][i][m][kJShape]) {
  545. size_i *= static_cast<size_t>(j);
  546. }
  547. std::string dtype = kernel_json[kJOpInfo][kJOutputs][i][m][kJDtype];
  548. size_t nbyte = tbe::GetDtypeNbyte(dtype);
  549. size_i *= nbyte;
  550. output_size_list->push_back(size_i);
  551. }
  552. }
  553. return true;
  554. }
// Builds the fusion-scope json ("op_list") for a fused kernel: first a json
// entry per compute node (which also extends *fusion_kernel with each op's
// name), then one entry per data input, concatenated data-first. Returns false
// when input layers or a data-input json cannot be generated.
bool TbeKernelBuild::GenFusionScopeJson(const std::vector<mindspore::AnfNodePtr> &input_nodes,
                                        const std::vector<mindspore::AnfNodePtr> &compute_nodes,
                                        nlohmann::json *fusion_str, std::string *fusion_kernel) {
  MS_EXCEPTION_IF_NULL(fusion_str);
  MS_EXCEPTION_IF_NULL(fusion_kernel);
  // get input layer info
  std::vector<std::vector<mindspore::AnfNodePtr>> input_layers;
  std::map<const AnfNodePtr, FusionDataType> spec_data_input;
  if (!GetInputLayers(input_nodes, compute_nodes, &input_layers, &spec_data_input)) {
    return false;
  }
  // gen fusion scope_op json
  std::vector<nlohmann::json> compute_list;
  (*fusion_kernel) = kFusionKernelNamePrfix;
  // index: fusion build option input record, next one from 0
  // NOTE(review): function-local static shared across the two passes below and
  // reset to 0 after each; this makes the function non-reentrant — confirm
  // builds are single-threaded.
  static size_t index = 0;
  auto layer_iter = input_layers.begin();
  auto compute_op_iter = compute_nodes.begin();
  // input_layers is built per compute node, so the two iterate in lock-step.
  for (; compute_op_iter != compute_nodes.end(); ++compute_op_iter, ++layer_iter) {
    nlohmann::json compute_op_str;
    (void)GenFusionComputeJson(*compute_op_iter, &layer_iter, &compute_op_str, fusion_kernel, &index);
    compute_list.push_back(compute_op_str);
  }
  index = 0;
  // gen data input json
  std::vector<nlohmann::json> data_list;
  for (const auto &layer : input_layers) {
    for (const auto &data_input : layer) {
      nlohmann::json data_str;
      if (!GenFusionDataInputJson(data_input, spec_data_input, &data_str, &index)) {
        MS_LOG(INFO) << "Fusion error: gen fusion datainput json faild.";
        return false;
      }
      data_list.push_back(data_str);
    }
  }
  index = 0;
  // Final op_list order: all data inputs, then all compute ops.
  data_list.insert(data_list.end(), compute_list.begin(), compute_list.end());
  (*fusion_str)[kFusionOpList] = data_list;
  return true;
}
// Fills `output_desc` with the description of output `node_out_idx` of
// `anf_node` for the fusion json: normalized name (suffixed with the output
// index when > 0), data type, device/infer shapes, format, and the descriptor
// output index. AddN inputs in NC1HWC0 and ReLUGradV2's mask input get special
// collapsed shapes expected by the fused TBE kernels.
void TbeKernelBuild::GenDescJson(const std::shared_ptr<mindspore::AnfNode> &anf_node, size_t node_out_idx,
                                 size_t desc_output_idx, nlohmann::json *output_desc, FusionDataType fusion_data_type) {
  std::string output_desc_name = anf_node->fullname_with_scope();
  if (node_out_idx > 0) {
    output_desc_name = output_desc_name + "_" + std::to_string(node_out_idx);
  }
  (*output_desc)[kJName] = NormalizeFullScopeName(output_desc_name);
  auto type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, node_out_idx);
  (*output_desc)[kJDataType] = tbe::TypeIdToString(type_id);
  auto ori_shape = AnfAlgo::GetOutputInferShape(anf_node, node_out_idx);
  if (ori_shape.empty()) {
    // Scalars are described as shape [1].
    ori_shape.emplace_back(1);
  }
  (*output_desc)[kJOriShape] = ori_shape;
  auto shape = AnfAlgo::GetOutputDeviceShape(anf_node, node_out_idx);
  if (shape.empty()) {
    shape.emplace_back(1);
  }
  (*output_desc)[kJShape] = shape;
  auto format = AnfAlgo::GetOutputFormat(anf_node, node_out_idx);
  if (format == kOpFormat_DEFAULT) {
    // DEFAULT is resolved by rank: 4-D tensors are NCHW, everything else ND.
    format = ori_shape.size() == 4 ? kOpFormat_NCHW : kOpFormat_ND;
  }
  (*output_desc)[kJFormat] = format;
  (*output_desc)[kJOriFormat] = kOpFormat_NCHW;
  (*output_desc)[kJOutputIndex] = desc_output_idx;
  // Special shapes for fused ops (assumes a 5-D NC1HWC0 device shape here —
  // indices 0..4 are read directly).
  if (fusion_data_type == kFusionAddN && format == kOpFormat_NC1HWC0) {
    // Collapse H and W: [N, C1, H*W, C0].
    std::vector<size_t> spec_shape = {};
    spec_shape.emplace_back(shape[0]);
    spec_shape.emplace_back(shape[1]);
    spec_shape.emplace_back(shape[2] * shape[3]);
    spec_shape.emplace_back(shape[4]);
    (*output_desc)[kJShape] = spec_shape;
  } else if (fusion_data_type == kFusionReLUGradV2) {
    // ReLUGradV2 mask: [N, C1, H*W, 16] with bool element type.
    std::vector<size_t> spec_shape = {};
    spec_shape.emplace_back(shape[0]);
    spec_shape.emplace_back(shape[1]);
    spec_shape.emplace_back(shape[2] * shape[3]);
    spec_shape.emplace_back(16);
    (*output_desc)[kJShape] = spec_shape;
    (*output_desc)[kJDataType] = kVTypeBool;
  }
}
  639. void TbeKernelBuild::GenReusedOutputDesc(const std::shared_ptr<mindspore::AnfNode> &anf_node, size_t index,
  640. size_t output_index, nlohmann::json *output_desc) {
  641. std::string output_desc_name = anf_node->fullname_with_scope() + "_" + std::to_string(index);
  642. (*output_desc)[kJName] = NormalizeFullScopeName(output_desc_name);
  643. (*output_desc)[kJOutputIndex] = output_index;
  644. std::vector<size_t> shape;
  645. (*output_desc)[kJShape] = shape;
  646. }
  647. bool TbeKernelBuild::GetSpecInputLayers(const std::string &op_name,
  648. const std::vector<mindspore::AnfNodePtr> &reorder_layer,
  649. std::map<const AnfNodePtr, FusionDataType> *spec_data_input) {
  650. if ((op_name == kReluGradV2OpName || op_name == kAddNOpName) && reorder_layer.empty()) {
  651. MS_LOG(INFO) << "Fusion error: node(" << op_name << " )'s input is null. ";
  652. return false;
  653. }
  654. MS_LOG(INFO) << "Fusion info: op_name: " << op_name << "input layer size: " << reorder_layer.size();
  655. if (op_name == kReluGradV2OpName) {
  656. (*spec_data_input)[reorder_layer[0]] = kFusionReLUGradV2;
  657. } else if (op_name == kAddNOpName) {
  658. for (const auto &it : reorder_layer) {
  659. (*spec_data_input)[it] = kFusionAddN;
  660. }
  661. }
  662. return true;
  663. }
  664. bool TbeKernelBuild::GetInputLayers(const std::vector<mindspore::AnfNodePtr> &input_nodes,
  665. const std::vector<mindspore::AnfNodePtr> &compute_nodes,
  666. std::vector<std::vector<mindspore::AnfNodePtr>> *input_layers,
  667. std::map<const AnfNodePtr, FusionDataType> *spec_data_input) {
  668. MS_EXCEPTION_IF_NULL(input_layers);
  669. MS_EXCEPTION_IF_NULL(spec_data_input);
  670. auto result = std::find_if(compute_nodes.begin(), compute_nodes.end(), [](const auto &it) {
  671. auto op_name = AnfAlgo::GetCNodeName(it);
  672. return op_name == kConv2DBackpropInputOpName;
  673. });
  674. bool need_spec = (result != compute_nodes.end());
  675. size_t input_size = 0;
  676. for (const auto &compute_node : compute_nodes) {
  677. std::vector<mindspore::AnfNodePtr> layer = {};
  678. std::vector<mindspore::AnfNodePtr> reorder_layer = {};
  679. MS_EXCEPTION_IF_NULL(compute_node);
  680. auto op_name = AnfAlgo::GetCNodeName(compute_node);
  681. auto ccompute_node = compute_node->cast<CNodePtr>();
  682. if (ccompute_node == nullptr) {
  683. MS_LOG(INFO) << "Fusion error: fusion compute node must be cnode";
  684. return false;
  685. }
  686. MS_LOG(INFO) << "Fusion info: compute name: " << compute_node->fullname_with_scope();
  687. for (size_t i = 1; i < ccompute_node->inputs().size(); ++i) {
  688. auto input = ccompute_node->input(i);
  689. auto find_iter = std::find(input_nodes.begin(), input_nodes.end(), input);
  690. if (find_iter != input_nodes.end()) {
  691. MS_LOG(INFO) << "Fusion info: add compute node's [" << i << "] input: " << input->fullname_with_scope();
  692. layer.emplace_back((*find_iter));
  693. } else {
  694. MS_LOG(INFO) << "Fusion warnig: this input [" << i << "] may be pre compute(" << input->fullname_with_scope()
  695. << ") node's output.";
  696. }
  697. }
  698. TbeAdapter::FusionDataOrderPass(op_name, layer, &reorder_layer);
  699. if (need_spec) {
  700. MS_LOG(INFO) << "Fusion info: match conv2d backprop input + ... patten.";
  701. if (!GetSpecInputLayers(op_name, reorder_layer, spec_data_input)) {
  702. return false;
  703. }
  704. }
  705. input_size += reorder_layer.size();
  706. input_layers->emplace_back(reorder_layer);
  707. }
  708. if (input_nodes.size() != input_size) {
  709. MS_LOG(INFO) << "Fusion error: fusion scope error, layer input:" << input_size
  710. << ", input_node:" << input_nodes.size();
  711. return false;
  712. }
  713. return true;
  714. }
  715. bool TbeKernelBuild::GenFusionDataInputJson(const std::shared_ptr<mindspore::AnfNode> &data_input,
  716. const std::map<const AnfNodePtr, FusionDataType> &spec_data_input,
  717. nlohmann::json *data_str, size_t *index) {
  718. MS_EXCEPTION_IF_NULL(data_str);
  719. MS_EXCEPTION_IF_NULL(index);
  720. std::vector<nlohmann::json> output_desc_list;
  721. if (!data_input) {
  722. MS_LOG(INFO) << "Data input is optional node";
  723. auto name = std::string(kOptional) + std::to_string(*index);
  724. (*data_str)[kJName] = name;
  725. nlohmann::json output_desc;
  726. output_desc[kJName] = name;
  727. output_desc[kJShape] = "NULL";
  728. output_desc_list.push_back(output_desc);
  729. (*index)++;
  730. } else {
  731. FusionDataType fusion_data_type = kFusionNormal;
  732. if (spec_data_input.find(data_input) != spec_data_input.end()) {
  733. fusion_data_type = spec_data_input.at(data_input);
  734. }
  735. auto kernel_idx = AnfAlgo::VisitKernel(data_input, 0);
  736. auto real_node = kernel_idx.first;
  737. size_t real_idx = kernel_idx.second;
  738. MS_LOG(INFO) << "Real name " << real_node->fullname_with_scope() << " index:" << real_idx;
  739. // kJOutputDesc
  740. nlohmann::json output_desc;
  741. GenDescJson(real_node, real_idx, real_idx, &output_desc, fusion_data_type);
  742. output_desc_list.push_back(output_desc);
  743. (*data_str)[kJName] = NormalizeFullScopeName(real_node->fullname_with_scope());
  744. }
  745. (*data_str)[kJOutputDesc] = output_desc_list;
  746. (*data_str)[kJtype] = "Data";
  747. return true;
  748. }
  749. bool TbeKernelBuild::IsDynamicInput(const mindspore::CNodePtr &cnode) {
  750. MS_EXCEPTION_IF_NULL(cnode);
  751. auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
  752. MS_EXCEPTION_IF_NULL(primitive);
  753. // for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input.
  754. bool ret = false;
  755. std::vector<int> dyn_input_sizes;
  756. auto dynamic_input_attr = primitive->GetAttr(kAttrDynInputSizes);
  757. if (dynamic_input_attr != nullptr) {
  758. dyn_input_sizes = GetValue<const std::vector<int>>(dynamic_input_attr);
  759. auto real_input_size = cnode->inputs().size() - 1;
  760. auto dyn_input_size = dyn_input_sizes.size();
  761. if (dyn_input_size != 1) {
  762. MS_LOG(INFO) << "Fusion error: fusion build not support dyn_input_sizes > 1";
  763. return ret;
  764. }
  765. if (IntToSize(dyn_input_sizes[0]) != real_input_size) {
  766. MS_LOG(INFO) << "Fusion error: dyn_input_size" << dyn_input_sizes[0] << "not equal real_input_size"
  767. << real_input_size;
  768. return ret;
  769. }
  770. ret = true;
  771. }
  772. return ret;
  773. }
  774. size_t TbeKernelBuild::GetOptionalInput(const mindspore::CNodePtr &cnode, bool is_dynamic_input) {
  775. MS_EXCEPTION_IF_NULL(cnode);
  776. if (is_dynamic_input) {
  777. return 0;
  778. }
  779. MS_EXCEPTION_IF_NULL(cnode);
  780. auto node_name = AnfAlgo::GetCNodeName(cnode);
  781. auto op_info = OpLib::FindOp(node_name, kTBE);
  782. MS_EXCEPTION_IF_NULL(cnode);
  783. if (op_info->inputs_ptr().size() < (cnode->inputs().size() - 1)) {
  784. MS_EXCEPTION(ArgumentError) << "op info error, node name:" << cnode->fullname_with_scope();
  785. }
  786. return (op_info->inputs_ptr().size() + 1 - cnode->inputs().size());
  787. }
  788. std::string TbeKernelBuild::GetRealOpType(const std::string &origin_type) {
  789. static std::map<std::string, std::string> buffer_fussion_op_map = {
  790. {parallel::DEPTHWISE_CONV2D_NATIVE, parallel::DEPTHWISE_CONV2D}, {parallel::TENSOR_ADD, parallel::ADD}};
  791. string result = origin_type;
  792. auto iter = buffer_fussion_op_map.find(origin_type);
  793. if (iter != buffer_fussion_op_map.end()) {
  794. result = iter->second;
  795. }
  796. return result;
  797. }
  798. bool TbeKernelBuild::GenFusionComputeInputJson(const mindspore::CNodePtr &cnode,
  799. std::vector<std::vector<mindspore::AnfNodePtr>>::iterator *layer_iter,
  800. std::vector<nlohmann::json> *input_desc_list, size_t *index) {
  801. MS_EXCEPTION_IF_NULL(cnode);
  802. MS_EXCEPTION_IF_NULL(input_desc_list);
  803. std::vector<nlohmann::json> input_desc_list_tmp = {};
  804. bool is_dynamic_input = IsDynamicInput(cnode);
  805. for (size_t i = 1; i < cnode->inputs().size(); ++i) {
  806. auto input = cnode->input(i);
  807. auto kernel_idx = AnfAlgo::VisitKernel(input, 0);
  808. auto real_node = kernel_idx.first;
  809. size_t real_idx = kernel_idx.second;
  810. MS_LOG(INFO) << "Real name" << real_node->fullname_with_scope() << "index:" << real_idx;
  811. nlohmann::json input_desc;
  812. GenDescJson(real_node, real_idx, real_idx, &input_desc);
  813. if (is_dynamic_input) {
  814. MS_LOG(INFO) << "Node has dynamic input.";
  815. input_desc[kJDynIndex] = (i - 1);
  816. }
  817. input_desc_list_tmp.emplace_back(input_desc);
  818. }
  819. size_t optional_num = GetOptionalInput(cnode, is_dynamic_input);
  820. if (optional_num > 0) {
  821. MS_LOG(INFO) << "Node has optional input.";
  822. for (size_t i = 0; i < optional_num; ++i) {
  823. nlohmann::json optional_input_desc;
  824. optional_input_desc[kJName] = std::string(kOptional) + std::to_string(*index);
  825. (*index)++;
  826. (*layer_iter)->emplace_back(nullptr);
  827. input_desc_list_tmp.emplace_back(optional_input_desc);
  828. }
  829. }
  830. auto op_name = AnfAlgo::GetCNodeName(cnode);
  831. TbeAdapter::FusionInputOrderPass(op_name, input_desc_list_tmp, input_desc_list);
  832. return true;
  833. }
  834. std::vector<size_t> TbeKernelBuild::GetDescOutputIndex(const std::vector<int> &output_used_nums) {
  835. std::vector<size_t> desc_output_index = {};
  836. for (size_t idx = 0; idx < output_used_nums.size(); ++idx) {
  837. auto output_use_num_item = output_used_nums[idx];
  838. MS_LOG(INFO) << "Output used num[" << idx << "] = " << output_use_num_item;
  839. desc_output_index.emplace_back(idx);
  840. if (output_use_num_item > 1) {
  841. desc_output_index.emplace_back(idx);
  842. }
  843. }
  844. return desc_output_index;
  845. }
// Build the JSON output descriptors for one fused compute node.
// When the node carries kAttrOutputUsedNum (its outputs are consumed both
// inside and outside the fusion scope), real descriptors are emitted first and
// placeholder "reused" descriptors are appended after them; otherwise one
// descriptor per inferred output is emitted.
// Returns false when the used-num attribute length does not match the output count.
bool TbeKernelBuild::GenFusionComputeOutputJson(const mindspore::CNodePtr &cnode,
                                                std::vector<nlohmann::json> *output_desc_list) {
  MS_EXCEPTION_IF_NULL(output_desc_list);
  auto output_size = AnfAlgo::GetOutputTensorNum(cnode);
  if (AnfAlgo::HasNodeAttr(kAttrOutputUsedNum, cnode)) {
    auto output_used_nums = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode, kAttrOutputUsedNum);
    MS_LOG(INFO) << "This node's output has been reused, node name: " << cnode->fullname_with_scope();
    // One used-num entry is required per output tensor.
    if (output_used_nums.size() != output_size) {
      MS_LOG(INFO) << "Fusion error: output tenor num(" << output_size << ")"
                   << " is not match output used num(" << output_used_nums.size() << ")";
      return false;
    }
    // desc_output_index repeats an output's index when its use count > 1,
    // so its size is >= output_size.
    auto desc_output_index = GetDescOutputIndex(output_used_nums);
    // Real outputs: full descriptors, mapped through desc_output_index.
    for (size_t i = 0; i < output_size; ++i) {
      MS_LOG(INFO) << "Fusion index: " << i << ", desc_output_index: " << desc_output_index[i];
      nlohmann::json output_desc;
      GenDescJson(cnode, i, desc_output_index[i], &output_desc);
      output_desc_list->emplace_back(output_desc);
    }
    // Extra slots beyond output_size: reuse placeholders (empty shape).
    for (size_t j = output_size; j < desc_output_index.size(); ++j) {
      MS_LOG(INFO) << "Fusion index: " << j << ", desc_output_index: " << desc_output_index[j];
      nlohmann::json output_desc;
      GenReusedOutputDesc(cnode, j, desc_output_index[j], &output_desc);
      output_desc_list->emplace_back(output_desc);
    }
  } else {
    // No reuse: one descriptor per output, desc index == output index.
    for (size_t i = 0; i < output_size; ++i) {
      nlohmann::json output_desc;
      GenDescJson(cnode, i, i, &output_desc);
      output_desc_list->push_back(output_desc);
    }
  }
  return true;
}
  880. bool TbeKernelBuild::GenFusionComputeJson(const mindspore::AnfNodePtr &compute_node,
  881. std::vector<std::vector<mindspore::AnfNodePtr>>::iterator *layer_iter,
  882. nlohmann::json *compute_op_str, std::string *fusion_kernel_name,
  883. size_t *index) {
  884. MS_EXCEPTION_IF_NULL(compute_node);
  885. auto cnode = compute_node->cast<CNodePtr>();
  886. MS_EXCEPTION_IF_NULL(cnode);
  887. // gen input desc
  888. std::vector<nlohmann::json> input_desc_list;
  889. (void)GenFusionComputeInputJson(cnode, layer_iter, &input_desc_list, index);
  890. (*compute_op_str)[kJInputDesc] = input_desc_list;
  891. // gen output desc
  892. std::vector<nlohmann::json> output_desc_list;
  893. if (!GenFusionComputeOutputJson(cnode, &output_desc_list)) {
  894. MS_LOG(INFO) << "Fusion Error: gen fusion output desc faild, node full name: " << cnode->fullname_with_scope();
  895. return false;
  896. }
  897. (*compute_op_str)[kJOutputDesc] = output_desc_list;
  898. // gen others
  899. auto origin_type = AnfAlgo::GetCNodeName(cnode);
  900. // replace special op type for buffer fusion op
  901. auto type = GetRealOpType(origin_type);
  902. (*compute_op_str)[kJtype] = type;
  903. tbe::TbeAdapter::NormalizeFuncName(&type);
  904. (*compute_op_str)[kJFuncName] = type;
  905. (*compute_op_str)[kJName] = NormalizeFullScopeName(cnode->fullname_with_scope());
  906. (void)(*fusion_kernel_name).append("_");
  907. (void)(*fusion_kernel_name).append(type);
  908. return true;
  909. }
  910. size_t TbeKernelBuild::GetIOSizeImpl(const nlohmann::json &desc) {
  911. size_t ret = 1;
  912. for (const auto &shape_item : desc[kJShape]) {
  913. ret *= static_cast<size_t>(shape_item);
  914. }
  915. std::string data_type = desc[kJDataType];
  916. size_t nbyte = tbe::GetDtypeNbyte(data_type);
  917. ret *= nbyte;
  918. return ret;
  919. }
// Compute the byte sizes of the fusion kernel's inputs and outputs.
// Inputs: every "Data" op in fusion_op_list contributes one size per output
// desc (optional placeholders with shape "NULL" are skipped). Outputs: each
// scope output node is resolved to its real producing op in fusion_op_list and
// sized from that op's output descs. Both result lists are cleared first.
// Returns false when a tuple-getitem output resolves to a desc with an empty
// shape (a reuse stub), which would make the size meaningless.
bool TbeKernelBuild::GetIOSize(const nlohmann::json &fusion_op_list,
                               const std::vector<mindspore::AnfNodePtr> &output_nodes,
                               std::vector<size_t> *input_size_list, std::vector<size_t> *output_size_list) {
  MS_EXCEPTION_IF_NULL(input_size_list);
  MS_EXCEPTION_IF_NULL(output_size_list);
  input_size_list->clear();
  output_size_list->clear();
  // Pass 1: input sizes from every "Data" op's output descs.
  for (const auto &op : fusion_op_list) {
    if (op[kJtype] == "Data") {
      const auto &data_output_desc = op[kJOutputDesc];
      for (const auto &data_output : data_output_desc) {
        // Shape "NULL" marks an optional-input placeholder — no real tensor.
        if (data_output[kJShape] == "NULL") {
          break;
        }
        auto ret = GetIOSizeImpl(data_output);
        input_size_list->push_back(ret);
        MS_LOG(INFO) << "Fusion info: scope input name: " << op[kJName] << ", size: " << ret;
      }
    }
  }
  // Pass 2: output sizes, matching each scope output to its producing op by
  // normalized full scope name.
  for (const auto &output_node : output_nodes) {
    auto kernel_idx = AnfAlgo::VisitKernel(output_node, 0);
    auto real_node = kernel_idx.first;
    size_t real_idx = kernel_idx.second;
    auto normal_name = NormalizeFullScopeName(real_node->fullname_with_scope());
    MS_LOG(INFO) << "Fusion info: real node name: " << normal_name << ", real output index: " << real_idx;
    for (const auto &op : fusion_op_list) {
      if (op[kJName] == normal_name) {
        auto op_output_desces = op[kJOutputDesc];
        if (output_node != real_node) {
          // tuple_get item
          MS_LOG(INFO) << "Output is a tuple getitem node";
          // Only the selected output index contributes; an empty shape here is
          // a reuse stub and cannot be sized.
          auto output_desc = op_output_desces[real_idx];
          if (output_desc[kJShape].empty()) {
            MS_LOG(INFO) << "Fusion error: output_desc's shape is empty. real_index " << real_idx;
            return false;
          }
          auto ret = GetIOSizeImpl(output_desc);
          output_size_list->push_back(ret);
          MS_LOG(INFO) << "Fusion info: scope output index: " << real_idx << ", size: " << ret;
        } else {
          // Direct output: size every desc; empty-shape (reuse stub) descs are
          // skipped rather than treated as errors.
          for (const auto &output_desc : op_output_desces) {
            if (output_desc[kJShape].empty()) {
              MS_LOG(INFO) << "Fusion info: output_desc's shape is empty, may be this node output";
              continue;
            }
            auto ret = GetIOSizeImpl(output_desc);
            output_size_list->push_back(ret);
            MS_LOG(INFO) << "Fusion info: scope output size: " << ret;
          }
        }
      }
    }
  }
  return true;
}
  976. } // namespace kernel
  977. } // namespace mindspore