You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

tbe_kernel_build.cc 34 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "kernel/tbe/tbe_kernel_build.h"
  17. #include <memory>
  18. #include <map>
  19. #include <algorithm>
  20. #include <unordered_set>
  21. #include "operator/ops.h"
  22. #include "session/anf_runtime_algorithm.h"
  23. #include "kernel/tbe/tbe_kernel_mod.h"
  24. #include "kernel/tbe/tbe_adapter.h"
  25. #include "kernel/tbe/tbe_python_funcs.h"
  26. #include "kernel/tbe/tbe_convert_utils.h"
  27. #include "kernel/tbe/tbe_utils.h"
namespace mindspore {
namespace kernel {
using mindspore::kernel::tbe::TbeAdapter;
using mindspore::kernel::tbe::TbeUtils;
// JSON key holding the op list inside a fusion-scope json.
constexpr auto kFusionOpList = "op_list";
// Name prefix for generated fusion kernels ("Prfix" is a historical typo kept for compatibility).
constexpr auto kFusionKernelNamePrfix = "te_fusion";
// Name prefix used for placeholder descs of missing optional inputs.
constexpr auto kOptional = "optional_";
// TBE spelling of the framework's FRAC_Z format.
constexpr auto kOpFormat_FRACTAL_Z = "FRACTAL_Z";
  36. std::string NormalizeFullScopeName(const string &full_scope_name) {
  37. // exp:Default/ReLU-op0 -->Default_ReLU_op0
  38. string normal_ret = full_scope_name;
  39. std::replace(normal_ret.begin(), normal_ret.end(), '/', '_');
  40. std::replace(normal_ret.begin(), normal_ret.end(), '-', '_');
  41. return normal_ret;
  42. }
// Build the single-op TBE kernel json for `anf_node` into `kernel_json`.
// Fills "platform"/"gen_model"/"impl_path" plus an "op_info" object holding
// the (normalized) op name and the inputs/outputs/attrs descriptions.  Also
// derives json_name_ (op name + hash of the op_info json) and caches the json
// text in json_info_.  Returns false when input or output json generation
// fails.
bool TbeKernelJsonCreator::GenTbeSingleKernelJson(const shared_ptr<mindspore::AnfNode> &anf_node,
                                                  nlohmann::json *kernel_json) {
  MS_EXCEPTION_IF_NULL(anf_node);
  MS_EXCEPTION_IF_NULL(kernel_json);
  std::string op_name = AnfAlgo::GetCNodeName(anf_node);
  auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kTBE);
  MS_EXCEPTION_IF_NULL(op_info_ptr);
  (*kernel_json)["platform"] = "TBE";
  (*kernel_json)["gen_model"] = "single";
  (*kernel_json)["impl_path"] = op_info_ptr->impl_path();
  nlohmann::json op_info_json;
  if (op_info_ptr->impl_path().empty()) {
    // built-in implementation: normalize the op name to the TBE func name
    tbe::TbeAdapter::NormalizeFuncName(&op_name);
  } else {
    // custom implementation: use the registered kernel name as-is
    op_name = op_info_ptr->kernel_name();
  }
  op_info_json["name"] = op_name;
  // generate inputs json
  nlohmann::json inputs_json;
  if (!GenTbeInputsJson(anf_node, op_info_ptr, &inputs_json)) {
    MS_LOG(ERROR) << "Anf Node [" << op_name << "] generate inputs json failed";
    return false;
  }
  op_info_json["inputs"] = inputs_json;
  // generate outputs json
  nlohmann::json outputs_json;
  if (!GenTbeOutputsJson(anf_node, op_info_ptr, &outputs_json)) {
    MS_LOG(ERROR) << "Anf Node [" << op_name << "] generate outputs json failed";
    return false;
  }
  op_info_json["outputs"] = outputs_json;
  // generate attrs json
  nlohmann::json attrs_json;
  (void)GenTbeAttrJson(anf_node, op_info_ptr, &attrs_json);
  op_info_json["attrs"] = attrs_json;
  // hash of the op_info json makes the kernel json name unique per config
  std::string json_str = op_info_json.dump();
  size_t hash_id = std::hash<std::string>()(json_str);
  json_name_ = op_name + "_" + std::to_string(hash_id);
  json_info_ = json_str;
  if (creater_type_ == PREBUILD) {
    // prebuild uses the node's normalized scope name instead of the hash name
    op_info_json["kernel_name"] = NormalizeFullScopeName(anf_node->fullname_with_scope());
  } else {
    op_info_json["kernel_name"] = json_name_;
  }
  (*kernel_json)["op_info"] = op_info_json;
  if (creater_type_ == SINGLE_BUILD) {
    TbeUtils::SaveJsonInfo(json_name_, json_info_);
  }
  MS_LOG(INFO) << "Operate type:" << creater_type_ << ", full scope name is :" << anf_node->fullname_with_scope()
               << ", json info name is : " << json_name_ << ", kernel json:" << kernel_json->dump();
  return true;
}
// Generate one input tensor description json and append it to `input_list`.
// `real_input_index` indexes the node's real (device) inputs, `value` becomes
// the desc's "valid" flag, and `op_input_name` + `input_i` form the desc name.
// TopK's "input_indices" input is special-cased through the TBE adapter.
bool TbeKernelJsonCreator::GenInputDescJson(const shared_ptr<AnfNode> &anf_node, size_t real_input_index, bool value,
                                            const shared_ptr<OpIOInfo> &input_ptr, const string &op_input_name,
                                            size_t input_i, vector<nlohmann::json> *input_list) {
  MS_EXCEPTION_IF_NULL(anf_node);
  MS_EXCEPTION_IF_NULL(input_ptr);
  MS_EXCEPTION_IF_NULL(input_list);
  std::string op_name = AnfAlgo::GetCNodeName(anf_node);
  if (input_ptr->name() == "input_indices" && op_name == kTopKOpName) {
    // TopK's indices input is emitted as a const tensor by the adapter
    TbeAdapter::GenTopKV2IndicesTensorInfo(anf_node, real_input_index, input_list, creater_type_);
  } else {
    // dtype : float16
    auto tensor_dtype =
      std::make_shared<TensorType>(TypeIdToType(AnfAlgo::GetInputDeviceDataType(anf_node, real_input_index)));
    MS_EXCEPTION_IF_NULL(tensor_dtype);
    std::string dtype = tensor_dtype->element()->ToString();
    dtype = tbe::DtypeToString(dtype);
    // format: map framework default/FRAC_Z spellings to TBE spellings
    std::string format = AnfAlgo::GetInputFormat(anf_node, real_input_index);
    if (format == kOpFormat_DEFAULT) {
      format = kOpFormat_NCHW;
    } else if (format == kOpFormat_FRAC_Z) {
      format = kOpFormat_FRACTAL_Z;
    }
    nlohmann::json input_desc_json;
    input_desc_json["dtype"] = dtype;
    input_desc_json["name"] = op_input_name + std::to_string(input_i);
    auto ori_shape = AnfAlgo::GetPrevNodeOutputInferShape(anf_node, real_input_index);
    if (ori_shape.empty()) {
      // scalar: represent as shape [1]
      ori_shape.emplace_back(1);
    }
    input_desc_json["ori_shape"] = ori_shape;
    input_desc_json["ori_format"] = kOpFormat_NCHW;
    auto shape = AnfAlgo::GetInputDeviceShape(anf_node, real_input_index);
    if (shape.empty()) {
      shape.emplace_back(1);
    }
    if (creater_type_ == OP_SELECT_FORMAT || creater_type_ == CHECK_SUPPORTED) {
      // format-selection / support-check queries use the infer shape and NCHW
      input_desc_json["shape"] = ori_shape;
      input_desc_json["format"] = kOpFormat_NCHW;
    } else {
      input_desc_json["shape"] = shape;
      input_desc_json["format"] = format;
    }
    input_desc_json["valid"] = value;
    input_list->emplace_back(input_desc_json);
  }
  return true;
}
// Generate `input_tensor_num` input descs for one registered input, advancing
// *real_input_index across the node's real input tensors.  When the
// registration expects more inputs than the node really has, "optional"
// inputs become invalid placeholder descs; anything else is an error.  For
// BatchNorm in training mode the mean/variance inputs are skipped entirely
// (the loop breaks after advancing the real input index).
bool TbeKernelJsonCreator::GenInputList(const shared_ptr<AnfNode> &anf_node, size_t input_tensor_num,
                                        const shared_ptr<OpIOInfo> &input_ptr, size_t *real_input_index,
                                        string *op_input_name, vector<nlohmann::json> *input_list) {
  MS_EXCEPTION_IF_NULL(anf_node);
  MS_EXCEPTION_IF_NULL(input_ptr);
  MS_EXCEPTION_IF_NULL(real_input_index);
  MS_EXCEPTION_IF_NULL(op_input_name);
  MS_EXCEPTION_IF_NULL(input_list);
  std::string op_name = AnfAlgo::GetCNodeName(anf_node);
  // NOTE(review): primitive is dereferenced in the BatchNorm branch below
  // without a null check — confirm GetCNodePrimitive cannot return null here.
  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
  size_t real_input_num = AnfAlgo::GetInputTensorNum(anf_node);
  bool value = true;
  for (size_t input_i = 0; input_i < input_tensor_num; input_i++) {
    if (*real_input_index >= real_input_num) {
      // node has fewer real inputs than registered: only legal for optionals
      if (input_ptr->param_type() == "optional") {
        *op_input_name = input_ptr->name() + "_optional_";
        nlohmann::json input_desc_json;
        input_desc_json["valid"] = false;
        input_desc_json["name"] = *op_input_name + std::to_string(*real_input_index);
        input_list->emplace_back(input_desc_json);
        continue;
      }
      MS_LOG(ERROR) << "input num" << *real_input_index << "is not match op inputs";
      return false;
    }
    if (op_name == "BatchNorm") {
      if (input_ptr->name() == "mean" || input_ptr->name() == "variance") {
        auto attr = primitive->GetAttr("is_training");
        MS_EXCEPTION_IF_NULL(attr);
        bool is_training = GetValue<bool>(attr);
        MS_LOG(INFO) << "op_name" << op_name << ", tensor_name " << input_ptr->name() << ", is_training "
                     << is_training;
        if (is_training) {
          // training-mode BatchNorm: mean/variance are not kernel inputs
          (*real_input_index)++;
          break;
        }
      }
    }
    bool ret = GenInputDescJson(anf_node, *real_input_index, value, input_ptr, *op_input_name, input_i, input_list);
    (*real_input_index)++;
    if (!ret) {
      return false;
    }
  }
  return true;
}
  189. bool GetInputNameAndRealNum(const std::shared_ptr<AnfNode> &anf_node, const shared_ptr<OpIOInfo> &input_ptr,
  190. size_t *dyn_input_index, size_t *input_num, std::string *op_input_name) {
  191. MS_EXCEPTION_IF_NULL(anf_node);
  192. MS_EXCEPTION_IF_NULL(input_ptr);
  193. MS_EXCEPTION_IF_NULL(dyn_input_index);
  194. MS_EXCEPTION_IF_NULL(input_num);
  195. MS_EXCEPTION_IF_NULL(op_input_name);
  196. auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
  197. // for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input.
  198. std::vector<int> dyn_input_sizes;
  199. if (primitive->GetAttr(kAttrDynInputSizes) != nullptr) {
  200. dyn_input_sizes = GetValue<const std::vector<int>>(primitive->GetAttr(kAttrDynInputSizes));
  201. }
  202. if (input_ptr->param_type() == "dynamic") {
  203. if (*dyn_input_index >= dyn_input_sizes.size()) {
  204. MS_LOG(ERROR) << "dyn input index" << *dyn_input_index << "is over dyn input num" << dyn_input_sizes.size();
  205. return false;
  206. }
  207. *input_num = IntToSize(dyn_input_sizes[*dyn_input_index]);
  208. *op_input_name = input_ptr->name() + "_dynamic_";
  209. (*dyn_input_index)++;
  210. // if optional input is exist
  211. } else {
  212. *input_num = 1;
  213. *op_input_name = input_ptr->name() + "_";
  214. }
  215. return true;
  216. }
// Build the "inputs" section of a single-op kernel json.
// Walks the op registration's input infos, resolving each to one or more
// tensor descs (dynamic inputs expand to several), then lets the TBE adapter
// reorder them as the kernel expects.  AtomicAddrClean emits no inputs.
bool TbeKernelJsonCreator::GenTbeInputsJson(const std::shared_ptr<AnfNode> &anf_node,
                                            const std::shared_ptr<OpInfo> &op_info, nlohmann::json *inputs_json) {
  MS_EXCEPTION_IF_NULL(anf_node);
  MS_EXCEPTION_IF_NULL(op_info);
  MS_EXCEPTION_IF_NULL(inputs_json);
  std::string op_name = AnfAlgo::GetCNodeName(anf_node);
  if (op_name == kAtomicAddrCleanOpName) {
    return true;
  }
  std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr = op_info->inputs_ptr();
  if (inputs_ptr.empty()) {
    MS_LOG(INFO) << "Apply kernel " << op_name << "registration info has no input info";
    return true;
  }
  auto op_info_input_num = inputs_ptr.size();
  size_t dyn_input_index = 0;   // position in the dyn_input_sizes attr
  size_t real_input_index = 0;  // position in the node's real input tensors
  std::vector<std::vector<nlohmann::json>> inputs_list;
  for (size_t i = 0; i < op_info_input_num; i++) {
    size_t input_tensor_num;
    std::shared_ptr<OpIOInfo> input_ptr = inputs_ptr[i];
    std::string op_input_name;
    MS_EXCEPTION_IF_NULL(input_ptr);
    if (!GetInputNameAndRealNum(anf_node, input_ptr, &dyn_input_index, &input_tensor_num, &op_input_name)) {
      return false;
    }
    std::vector<nlohmann::json> input_list;
    if (!GenInputList(anf_node, input_tensor_num, input_ptr, &real_input_index, &op_input_name, &input_list)) {
      return false;
    }
    inputs_list.emplace_back(input_list);
  }
  // the adapter may reorder inputs for specific ops before emitting json
  TbeAdapter::InputOrderPass(op_name, inputs_list, inputs_json);
  return true;
}
  252. bool TbeKernelJsonCreator::GenTbeOutputsJson(const std::shared_ptr<AnfNode> &anf_node,
  253. const std::shared_ptr<OpInfo> &op_info, nlohmann::json *outputs_json) {
  254. MS_EXCEPTION_IF_NULL(anf_node);
  255. MS_EXCEPTION_IF_NULL(op_info);
  256. MS_EXCEPTION_IF_NULL(outputs_json);
  257. auto op_name = AnfAlgo::GetCNodeName(anf_node);
  258. if (op_name == kAtomicAddrCleanOpName) {
  259. return true;
  260. }
  261. auto outputs_ptr = op_info->outputs_ptr();
  262. return GenOutputDescJson(anf_node, outputs_ptr, outputs_json);
  263. }
// Build the output desc lists from the op registration's output infos.
// "required" outputs map 1:1 to real outputs; a "dynamic" output absorbs all
// real outputs and must be the only registered output; any other (optional)
// output beyond the real output count is emitted as an invalid placeholder.
bool TbeKernelJsonCreator::GenOutputDescJson(const shared_ptr<mindspore::AnfNode> &anf_node,
                                             const vector<shared_ptr<mindspore::kernel::OpIOInfo>> &outputs_ptr,
                                             nlohmann::json *outputs_json) {
  MS_EXCEPTION_IF_NULL(outputs_json);
  size_t output_idx = 0;
  auto op_name = AnfAlgo::GetCNodeName(anf_node);
  size_t real_output_num = AnfAlgo::GetOutputTensorNum(anf_node);
  for (const auto &output_ptr : outputs_ptr) {
    size_t output_obj_num = 0;
    if (output_ptr->param_type() == "required") {
      output_obj_num = 1;
    } else if (output_ptr->param_type() == "dynamic") {
      if (outputs_ptr.size() > 1) {
        MS_LOG(ERROR) << "Dynamic output is unsupported multi output!";
        return false;
      }
      output_obj_num = real_output_num;
    } else {
      // optional output missing at runtime: emit an invalid placeholder desc
      if (output_idx >= real_output_num) {
        MS_LOG(INFO) << "op:" << op_name << ", output" << output_ptr->name() << " is optional, output is none.";
        std::vector<nlohmann::json> output_list;
        nlohmann::json output_obj;
        output_obj["name"] = output_ptr->name();
        output_obj["valid"] = false;
        output_list.emplace_back(output_obj);
        (*outputs_json).push_back(output_list);
        continue;
      } else {
        output_obj_num = 1;
      }
    }
    std::vector<nlohmann::json> output_list;
    GenOutputList(anf_node, output_obj_num, output_ptr, &output_idx, &output_list);
    (*outputs_json).push_back(output_list);
  }
  return true;
}
// Append `output_obj_num` output tensor descs for one registered output,
// advancing *output_idx across the node's real outputs.
void TbeKernelJsonCreator::GenOutputList(const shared_ptr<AnfNode> &anf_node, const size_t &output_obj_num,
                                         const shared_ptr<OpIOInfo> &output_ptr, size_t *output_idx,
                                         vector<nlohmann::json> *output_list) {
  MS_EXCEPTION_IF_NULL(output_idx);
  MS_EXCEPTION_IF_NULL(output_list);
  for (size_t i = 0; i < output_obj_num; i++) {
    nlohmann::json output_obj;
    auto type_ptr = std::make_shared<TensorType>(TypeIdToType(AnfAlgo::GetOutputDeviceDataType(anf_node, *output_idx)));
    std::string dtype = type_ptr->element()->ToString();
    dtype = tbe::DtypeToString(dtype);
    // map framework default/FRAC_Z format spellings to TBE spellings
    std::string format = AnfAlgo::GetOutputFormat(anf_node, *output_idx);
    if (format == kOpFormat_DEFAULT) {
      format = kOpFormat_NCHW;
    } else if (format == kOpFormat_FRAC_Z) {
      format = kOpFormat_FRACTAL_Z;
    }
    std::vector<size_t> ori_shape;
    if (AnfAlgo::GetOutputInferShape(anf_node, *output_idx).empty()) {
      // scalar: represent as shape [1]
      ori_shape.emplace_back(1);
    } else {
      ori_shape = AnfAlgo::GetOutputInferShape(anf_node, *output_idx);
    }
    output_obj["dtype"] = dtype;
    auto shape = AnfAlgo::GetOutputDeviceShape(anf_node, *output_idx);
    if (shape.empty()) {
      shape.emplace_back(1);
    }
    if (creater_type_ == OP_SELECT_FORMAT || creater_type_ == CHECK_SUPPORTED) {
      // format-selection / support-check queries use the infer shape and NCHW
      output_obj["shape"] = ori_shape;
      output_obj["format"] = kOpFormat_NCHW;
    } else {
      output_obj["shape"] = shape;
      output_obj["format"] = format;
    }
    output_obj["ori_shape"] = ori_shape;
    output_obj["ori_format"] = kOpFormat_NCHW;
    output_obj["name"] = output_ptr->name();
    output_obj["valid"] = true;
    output_list->emplace_back(output_obj);
    (*output_idx)++;
  }
}
// Build the "attrs" section from the op registration's attr infos.
// Op-specific attr handling is delegated to TbeAdapter::RunAttrPass first; if
// it declines, each registered attr that is present on the primitive is
// serialized through ParseAttrValue.  A missing "required" attr is only fatal
// for SINGLE_BUILD of ops with a custom impl path.
bool TbeKernelJsonCreator::GenTbeAttrJson(const std::shared_ptr<AnfNode> &anf_node,
                                          const std::shared_ptr<OpInfo> &op_info, nlohmann::json *attrs_json) {
  MS_EXCEPTION_IF_NULL(anf_node);
  MS_EXCEPTION_IF_NULL(op_info);
  MS_EXCEPTION_IF_NULL(attrs_json);
  auto attrs_ptr = op_info->attrs_ptr();
  if (TbeAdapter::RunAttrPass(anf_node, attrs_ptr, attrs_json)) {
    // adapter already produced the attrs for this op
    return true;
  }
  auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
  MS_EXCEPTION_IF_NULL(primitive);
  for (const auto &attr_ptr : attrs_ptr) {
    std::string attr_name = attr_ptr->name();
    if (primitive->GetAttr(attr_name) != nullptr) {
      nlohmann::json attr_obj;
      auto value = primitive->GetAttr(attr_name);
      std::string type = attr_ptr->type();
      ParseAttrValue(type, value, &attr_obj);
      attr_obj["name"] = attr_name;
      attr_obj["valid"] = true;
      (*attrs_json).push_back(attr_obj);
    } else {
      if (attr_ptr->param_type() == "required" && creater_type_ == SINGLE_BUILD && op_info->impl_path() != "") {
        MS_LOG(EXCEPTION) << "op name: " << op_info->op_name() << " attr: " << attr_name
                          << " is required, but not set.";
      }
    }
  }
  return true;
}
  373. void TbeKernelJsonCreator::ParseAttrValue(const std::string &type, const mindspore::ValuePtr &value,
  374. nlohmann::json *attr_obj) {
  375. MS_EXCEPTION_IF_NULL(value);
  376. MS_EXCEPTION_IF_NULL(attr_obj);
  377. if (type == "int") {
  378. auto attr_value = GetValue<int>(value);
  379. (*attr_obj)["value"] = attr_value;
  380. } else if (type == "str") {
  381. auto attr_value = GetValue<std::string>(value);
  382. if (attr_value == kOpFormat_FRAC_Z) {
  383. attr_value = kOpFormat_FRACTAL_Z;
  384. }
  385. (*attr_obj)["value"] = attr_value;
  386. } else if (type == "bool") {
  387. auto attr_value = GetValue<bool>(value);
  388. (*attr_obj)["value"] = attr_value;
  389. } else if (type == "float") {
  390. auto attr_value = GetValue<float>(value);
  391. (*attr_obj)["value"] = attr_value;
  392. } else if (type == "listInt") {
  393. std::vector<int> attr_value;
  394. auto value_type = value->type();
  395. MS_EXCEPTION_IF_NULL(value_type);
  396. auto value_type_str = value_type->ToString();
  397. if (value_type_str == "Int32") {
  398. int data = GetValue<int>(value);
  399. attr_value.push_back(data);
  400. } else {
  401. attr_value = GetValue<std::vector<int>>(value);
  402. }
  403. (*attr_obj)["value"] = attr_value;
  404. } else if (type == "listListInt") {
  405. auto attr_value = GetValue<std::vector<std::vector<int>>>(value);
  406. (*attr_obj)["value"] = attr_value;
  407. } else {
  408. MS_LOG(EXCEPTION) << "type: " << type << "not support";
  409. }
  410. }
  411. bool TbeKernelBuild::GetIOSize(const nlohmann::json &kernel_json, std::vector<size_t> *input_size_list,
  412. std::vector<size_t> *output_size_list) {
  413. if (input_size_list == nullptr || output_size_list == nullptr) {
  414. MS_LOG(ERROR) << "input size or output size is nullptr";
  415. return false;
  416. }
  417. input_size_list->clear();
  418. output_size_list->clear();
  419. for (size_t i = 0; i < kernel_json["op_info"]["inputs"].size(); i++) {
  420. for (size_t m = 0; m < kernel_json["op_info"]["inputs"][i].size(); m++) {
  421. size_t size_i = 1;
  422. if (kernel_json["op_info"]["inputs"][i][m]["valid"] == false) {
  423. std::string input_name = kernel_json["op_info"]["inputs"][i][m]["name"];
  424. MS_LOG(INFO) << "Input name:" << input_name << "is optional, valid is false.";
  425. continue;
  426. }
  427. for (const auto &j : kernel_json["op_info"]["inputs"][i][m]["shape"]) {
  428. size_i *= static_cast<size_t>(j);
  429. }
  430. std::string dtype = kernel_json["op_info"]["inputs"][i][m]["dtype"];
  431. size_t nbyte = tbe::GetDtypeNbyte(dtype);
  432. size_i *= nbyte;
  433. input_size_list->push_back(size_i);
  434. }
  435. }
  436. for (size_t i = 0; i < kernel_json["op_info"]["outputs"].size(); i++) {
  437. for (size_t m = 0; m < kernel_json["op_info"]["outputs"][i].size(); m++) {
  438. size_t size_i = 1;
  439. if (kernel_json["op_info"]["outputs"][i][m]["valid"] == false) {
  440. std::string output_name = kernel_json["op_info"]["outputs"][i][m]["name"];
  441. MS_LOG(INFO) << "Output name:" << output_name << " is optional, valid is false.";
  442. continue;
  443. }
  444. for (const auto &j : kernel_json["op_info"]["outputs"][i][m]["shape"]) {
  445. size_i *= static_cast<size_t>(j);
  446. }
  447. std::string dtype = kernel_json["op_info"]["outputs"][i][m]["dtype"];
  448. size_t nbyte = tbe::GetDtypeNbyte(dtype);
  449. size_i *= nbyte;
  450. output_size_list->push_back(size_i);
  451. }
  452. }
  453. return true;
  454. }
// Build the fusion-scope json for a fused kernel and derive its kernel name.
// Compute nodes are serialized first (each paired with its layer of scope
// inputs, and passing `fusion_kernel` so compute-op generation can extend the
// name — presumably; confirm against GenFusionComputeJson), then the input
// layers' data nodes are serialized as "Data" entries; data entries precede
// compute entries in the final "op_list".
// NOTE(review): `index` is a function-local static, so this function is not
// reentrant or thread-safe.
bool TbeKernelBuild::GenFusionScopeJson(const vector<mindspore::AnfNodePtr> &input_nodes,
                                        const vector<mindspore::AnfNodePtr> &compute_nodes, nlohmann::json *fusion_str,
                                        std::string *fusion_kernel) {
  MS_EXCEPTION_IF_NULL(fusion_str);
  MS_EXCEPTION_IF_NULL(fusion_kernel);
  // get input layer info
  std::vector<std::vector<mindspore::AnfNodePtr>> input_layers;
  if (!GetInputLayers(input_nodes, compute_nodes, &input_layers)) {
    return false;
  }
  // gen fusion scope_op json
  vector<nlohmann::json> compute_list;
  (*fusion_kernel) = kFusionKernelNamePrfix;
  // index: fusion build option input record, next one from 0
  static size_t index = 0;
  auto layer_iter = input_layers.begin();
  auto compute_op_iter = compute_nodes.begin();
  for (; compute_op_iter != compute_nodes.end(); ++compute_op_iter, ++layer_iter) {
    nlohmann::json compute_op_str;
    (void)GenFusionComputeJson(*compute_op_iter, &layer_iter, &compute_op_str, fusion_kernel, &index);
    compute_list.push_back(compute_op_str);
  }
  index = 0;
  // gen data input json
  vector<nlohmann::json> data_list;
  for (const auto &layer : input_layers) {
    for (const auto &data_input : layer) {
      nlohmann::json data_str;
      if (!GenFusionDataInputJson(data_input, &data_str, &index)) {
        MS_LOG(DEBUG) << "GenFusionDataInputJson faild.";
        return false;
      }
      data_list.push_back(data_str);
    }
  }
  index = 0;
  // data entries first, then compute entries
  data_list.insert(data_list.end(), compute_list.begin(), compute_list.end());
  (*fusion_str)[kFusionOpList] = data_list;
  return true;
}
// Fill a tensor desc for output `node_out_idx` of `anf_node`: normalized
// scope name (suffixed with the index when > 0), data_type, ori_shape/shape
// (empty shapes become [1]), format and the desc-level output_index.
// A default format resolves to NCHW for 4-D infer shapes, ND otherwise.
void TbeKernelBuild::GenDescJson(const std::shared_ptr<mindspore::AnfNode> &anf_node, size_t node_out_idx,
                                 size_t desc_output_idx, nlohmann::json *output_desc) {
  std::string output_desc_name = anf_node->fullname_with_scope();
  if (node_out_idx > 0) {
    output_desc_name = output_desc_name + "_" + std::to_string(node_out_idx);
  }
  (*output_desc)["name"] = NormalizeFullScopeName(output_desc_name);
  auto type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, node_out_idx);
  (*output_desc)["data_type"] = tbe::TypeIdToString(type_id);
  auto ori_shape = AnfAlgo::GetOutputInferShape(anf_node, node_out_idx);
  if (ori_shape.empty()) {
    // scalar: represent as shape [1]
    ori_shape.emplace_back(1);
  }
  (*output_desc)["ori_shape"] = ori_shape;
  auto shape = AnfAlgo::GetOutputDeviceShape(anf_node, node_out_idx);
  if (shape.empty()) {
    shape.emplace_back(1);
  }
  (*output_desc)["shape"] = shape;
  auto format = AnfAlgo::GetOutputFormat(anf_node, node_out_idx);
  if (format == kOpFormat_DEFAULT) {
    format = ori_shape.size() == 4 ? format : format;  // see below
    if (ori_shape.size() == 4) {
      format = kOpFormat_NCHW;
    } else {
      format = kOpFormat_ND;
    }
  }
  (*output_desc)["format"] = format;
  (*output_desc)["ori_format"] = kOpFormat_NCHW;
  (*output_desc)["output_index"] = desc_output_idx;
}
  526. void TbeKernelBuild::GenReusedOutputDesc(const shared_ptr<mindspore::AnfNode> &anf_node, size_t index,
  527. size_t output_index, nlohmann::json *output_desc) {
  528. std::string output_desc_name = anf_node->fullname_with_scope() + "_" + std::to_string(index);
  529. (*output_desc)["name"] = NormalizeFullScopeName(output_desc_name);
  530. (*output_desc)["data_type"] = tbe::TypeIdToString(kNumberTypeFloat32);
  531. (*output_desc)["output_index"] = output_index;
  532. std::vector<size_t> shape;
  533. (*output_desc)["shape"] = shape;
  534. }
// Group the fusion-scope input nodes into per-compute-node "layers": layer i
// holds, in input order, those inputs of compute node i that appear in
// `input_nodes`.  Fails if any compute node is not a CNode, or if the total
// number of collected inputs differs from input_nodes.size().
bool TbeKernelBuild::GetInputLayers(const vector<mindspore::AnfNodePtr> &input_nodes,
                                    const vector<mindspore::AnfNodePtr> &compute_nodes,
                                    std::vector<std::vector<mindspore::AnfNodePtr>> *input_layers) {
  size_t input_size = 0;
  for (const auto &compute_node : compute_nodes) {
    std::vector<mindspore::AnfNodePtr> layer;
    MS_EXCEPTION_IF_NULL(compute_node);
    auto ccompute_node = compute_node->cast<CNodePtr>();
    if (ccompute_node == nullptr) {
      MS_LOG(DEBUG) << "fusion compute node must be cnode";
      return false;
    }
    // inputs()[0] is the primitive, so real inputs start at index 1
    for (size_t i = 1; i < ccompute_node->inputs().size(); ++i) {
      auto input = ccompute_node->input(i);
      auto find_iter = std::find(input_nodes.begin(), input_nodes.end(), input);
      if (find_iter != input_nodes.end()) {
        layer.emplace_back((*find_iter));
      }
    }
    input_size += layer.size();
    input_layers->emplace_back(layer);
  }
  if (input_nodes.size() != input_size) {
    MS_LOG(DEBUG) << "fusion scope error, layer input:" << input_size << ", input_node:" << input_nodes.size();
    return false;
  }
  return true;
}
// Serialize one fusion-scope data input as a "Data" op entry.
// A null `data_input` stands for a missing optional input and produces a
// placeholder named "optional_<index>" with shape "NULL"; otherwise the real
// producing node and output index are resolved via VisitKernel and described
// normally.
bool TbeKernelBuild::GenFusionDataInputJson(const shared_ptr<mindspore::AnfNode> &data_input, nlohmann::json *data_str,
                                            size_t *index) {
  MS_EXCEPTION_IF_NULL(data_str);
  MS_EXCEPTION_IF_NULL(index);
  std::vector<nlohmann::json> output_desc_list;
  if (!data_input) {
    MS_LOG(INFO) << "data input is optional node";
    auto name = std::string(kOptional) + std::to_string(*index);
    (*data_str)["name"] = name;
    nlohmann::json output_desc;
    output_desc["name"] = name;
    output_desc["shape"] = "NULL";
    output_desc_list.push_back(output_desc);
    (*index)++;
  } else {
    // resolve through depends/tuple-gets to the real producing kernel
    auto kernel_idx = AnfAlgo::VisitKernel(data_input, 0);
    auto real_node = kernel_idx.first;
    size_t real_idx = kernel_idx.second;
    MS_LOG(INFO) << "real name " << real_node->fullname_with_scope() << " index:" << real_idx;
    // "output_desc"
    nlohmann::json output_desc;
    GenDescJson(real_node, real_idx, real_idx, &output_desc);
    output_desc_list.push_back(output_desc);
    (*data_str)["name"] = NormalizeFullScopeName(real_node->fullname_with_scope());
  }
  (*data_str)["output_desc"] = output_desc_list;
  (*data_str)["type"] = "Data";
  return true;
}
  592. bool TbeKernelBuild::IsDynamicInput(const mindspore::CNodePtr &cnode) {
  593. MS_EXCEPTION_IF_NULL(cnode);
  594. auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
  595. MS_EXCEPTION_IF_NULL(primitive);
  596. // for dynamic input number, dyn_input_sizes has the info of dynamic input num for each input.
  597. bool ret = false;
  598. std::vector<int> dyn_input_sizes;
  599. auto dynamic_input_attr = primitive->GetAttr(kAttrDynInputSizes);
  600. if (dynamic_input_attr != nullptr) {
  601. dyn_input_sizes = GetValue<const std::vector<int>>(dynamic_input_attr);
  602. auto real_input_size = cnode->inputs().size() - 1;
  603. auto dyn_input_size = dyn_input_sizes.size();
  604. if (dyn_input_size != 1) {
  605. MS_LOG(DEBUG) << "fusion build not support dyn_input_sizes > 1";
  606. return ret;
  607. }
  608. if (IntToSize(dyn_input_sizes[0]) != real_input_size) {
  609. MS_LOG(DEBUG) << " dyn_input_size" << dyn_input_sizes[0] << "not equal real_input_size" << real_input_size;
  610. return ret;
  611. }
  612. ret = true;
  613. }
  614. return ret;
  615. }
  616. size_t TbeKernelBuild::GetOptionalInput(const mindspore::CNodePtr &cnode, bool is_dynamic_input) {
  617. if (is_dynamic_input) {
  618. return 0;
  619. }
  620. MS_EXCEPTION_IF_NULL(cnode);
  621. auto node_name = AnfAlgo::GetCNodeName(cnode);
  622. auto op_info = OpLib::FindOp(node_name, kTBE);
  623. MS_EXCEPTION_IF_NULL(cnode);
  624. if (op_info->inputs_ptr().size() < (cnode->inputs().size() - 1)) {
  625. MS_EXCEPTION(ArgumentError) << "op info error, node name:" << cnode->fullname_with_scope();
  626. }
  627. return (op_info->inputs_ptr().size() + 1 - cnode->inputs().size());
  628. }
// Build the input desc list for one fused compute node.
// Each real input is resolved through VisitKernel to its producing node and
// output index; dynamic inputs additionally record their "dyn_index".
// Missing optional inputs get placeholder descs named "optional_<index>", and
// a nullptr is appended to the node's input layer so that data-input
// generation emits a matching placeholder entry.
bool TbeKernelBuild::GenFusionComputeInputJson(const mindspore::CNodePtr &cnode,
                                               std::vector<std::vector<mindspore::AnfNodePtr>>::iterator *layer_iter,
                                               std::vector<nlohmann::json> *input_desc_list, size_t *index) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(input_desc_list);
  bool is_dynamic_input = IsDynamicInput(cnode);
  // inputs()[0] is the primitive, so real inputs start at index 1
  for (size_t i = 1; i < cnode->inputs().size(); ++i) {
    auto input = cnode->input(i);
    auto kernel_idx = AnfAlgo::VisitKernel(input, 0);
    auto real_node = kernel_idx.first;
    size_t real_idx = kernel_idx.second;
    MS_LOG(INFO) << "real name" << real_node->fullname_with_scope() << "index:" << real_idx;
    nlohmann::json input_desc;
    GenDescJson(real_node, real_idx, real_idx, &input_desc);
    if (is_dynamic_input) {
      MS_LOG(INFO) << "node has dynamic input.";
      input_desc["dyn_index"] = (i - 1);
    }
    (*input_desc_list).emplace_back(input_desc);
  }
  size_t optional_num = GetOptionalInput(cnode, is_dynamic_input);
  if (optional_num > 0) {
    // pad missing optional inputs with placeholders and null layer entries
    MS_LOG(INFO) << "node has optional input.";
    for (size_t i = 0; i < optional_num; ++i) {
      nlohmann::json optional_input_desc;
      optional_input_desc["name"] = std::string(kOptional) + std::to_string(*index);
      (*index)++;
      (*layer_iter)->emplace_back(nullptr);
      (*input_desc_list).emplace_back(optional_input_desc);
    }
  }
  return true;
}
  662. std::vector<size_t> TbeKernelBuild::GetDescOutputIndex(const std::vector<int> &output_used_nums) {
  663. std::vector<size_t> desc_output_index = {};
  664. bool find_reused = false;
  665. size_t reused_num = 0;
  666. for (size_t idx = 0; idx < output_used_nums.size(); ++idx) {
  667. auto output_use_num_item = output_used_nums[idx];
  668. MS_LOG(INFO) << "output used num[" << idx << "] = " << output_use_num_item;
  669. if (output_use_num_item == 1 || output_use_num_item == 0) {
  670. desc_output_index.emplace_back(idx);
  671. } else {
  672. if (!find_reused) {
  673. desc_output_index.emplace_back(idx);
  674. } else {
  675. desc_output_index.emplace_back(desc_output_index[idx - 1]);
  676. }
  677. reused_num += (output_use_num_item - 1);
  678. find_reused = true;
  679. }
  680. }
  681. auto pad_value = output_used_nums.size() == 1 ? 0 : desc_output_index[desc_output_index.size() - 1] + 1;
  682. for (size_t i = 0; i < reused_num; ++i) {
  683. desc_output_index.emplace_back(pad_value);
  684. }
  685. return desc_output_index;
  686. }
  687. bool TbeKernelBuild::GenFusionComputeOutputJson(const mindspore::CNodePtr &cnode,
  688. std::vector<nlohmann::json> *output_desc_list) {
  689. auto output_size = AnfAlgo::GetOutputTensorNum(cnode);
  690. if (AnfAlgo::HasNodeAttr(kAttrOutputUsedNum, cnode)) {
  691. auto output_used_nums = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode, kAttrOutputUsedNum);
  692. MS_LOG(INFO) << "This node's output has been reused, node name: " << cnode->fullname_with_scope();
  693. if (output_used_nums.size() != output_size) {
  694. MS_LOG(INFO) << "Fusion error: output tenor num(" << output_size << ")"
  695. << " is not match output used num(" << output_used_nums.size() << ")";
  696. return false;
  697. }
  698. auto desc_output_index = GetDescOutputIndex(output_used_nums);
  699. for (size_t i = 0; i < output_size; ++i) {
  700. MS_LOG(INFO) << "Fusion index: " << i << ", desc_output_index: " << desc_output_index[i];
  701. nlohmann::json output_desc;
  702. GenDescJson(cnode, i, desc_output_index[i], &output_desc);
  703. output_desc_list->emplace_back(output_desc);
  704. }
  705. for (size_t j = output_size; j < desc_output_index.size(); ++j) {
  706. MS_LOG(INFO) << "Fusion index: " << j << ", desc_output_index: " << desc_output_index[j];
  707. nlohmann::json output_desc;
  708. GenReusedOutputDesc(cnode, j, desc_output_index[j], &output_desc);
  709. output_desc_list->emplace_back(output_desc);
  710. }
  711. } else {
  712. for (size_t i = 0; i < output_size; ++i) {
  713. nlohmann::json output_desc;
  714. GenDescJson(cnode, i, i, &output_desc);
  715. output_desc_list->push_back(output_desc);
  716. }
  717. }
  718. return true;
  719. }
  720. bool TbeKernelBuild::GenFusionComputeJson(const mindspore::AnfNodePtr &compute_node,
  721. std::vector<std::vector<mindspore::AnfNodePtr>>::iterator *layer_iter,
  722. nlohmann::json *compute_op_str, std::string *fusion_kernel_name,
  723. size_t *index) {
  724. MS_EXCEPTION_IF_NULL(compute_node);
  725. auto cnode = compute_node->cast<CNodePtr>();
  726. MS_EXCEPTION_IF_NULL(cnode);
  727. // gen input desc
  728. std::vector<nlohmann::json> input_desc_list;
  729. (void)GenFusionComputeInputJson(cnode, layer_iter, &input_desc_list, index);
  730. (*compute_op_str)["input_desc"] = input_desc_list;
  731. // gen output desc
  732. std::vector<nlohmann::json> output_desc_list;
  733. if (!GenFusionComputeOutputJson(cnode, &output_desc_list)) {
  734. MS_LOG(INFO) << "Fusion Error: gen fusion output desc faild, node full name: " << cnode->fullname_with_scope();
  735. return false;
  736. }
  737. (*compute_op_str)["output_desc"] = output_desc_list;
  738. // gen others
  739. auto type = AnfAlgo::GetCNodeName(cnode);
  740. if (type == "TensorAdd") {
  741. type = "Add";
  742. }
  743. (*compute_op_str)["type"] = type;
  744. tbe::TbeAdapter::NormalizeFuncName(&type);
  745. (*compute_op_str)["func_name"] = type;
  746. (*compute_op_str)["name"] = NormalizeFullScopeName(cnode->fullname_with_scope());
  747. (void)(*fusion_kernel_name).append("_");
  748. (void)(*fusion_kernel_name).append(type);
  749. return true;
  750. }
  751. size_t TbeKernelBuild::GetIOSizeImpl(const nlohmann::json &desc) {
  752. size_t ret = 1;
  753. for (const auto &shape_item : desc["shape"]) {
  754. ret *= static_cast<size_t>(shape_item);
  755. }
  756. std::string data_type = desc["data_type"];
  757. size_t nbyte = tbe::GetDtypeNbyte(data_type);
  758. ret *= nbyte;
  759. return ret;
  760. }
  761. bool TbeKernelBuild::GetIOSize(const nlohmann::json &fusion_op_list, const vector<mindspore::AnfNodePtr> &output_nodes,
  762. std::vector<size_t> *input_size_list, std::vector<size_t> *output_size_list) {
  763. MS_EXCEPTION_IF_NULL(input_size_list);
  764. MS_EXCEPTION_IF_NULL(output_size_list);
  765. input_size_list->clear();
  766. output_size_list->clear();
  767. for (const auto &op : fusion_op_list) {
  768. if (op["type"] == "Data") {
  769. const auto &data_output_desc = op["output_desc"];
  770. for (const auto &data_output : data_output_desc) {
  771. if (data_output["shape"] == "NULL") {
  772. break;
  773. }
  774. auto ret = GetIOSizeImpl(data_output);
  775. input_size_list->push_back(ret);
  776. }
  777. }
  778. }
  779. for (const auto &output_node : output_nodes) {
  780. auto kernel_idx = AnfAlgo::VisitKernel(output_node, 0);
  781. auto real_node = kernel_idx.first;
  782. size_t real_idx = kernel_idx.second;
  783. for (const auto &op : fusion_op_list) {
  784. auto normal_name = NormalizeFullScopeName(real_node->fullname_with_scope());
  785. if (op["name"] == normal_name) {
  786. auto op_output_desces = op["output_desc"];
  787. if (output_node != real_node) {
  788. // tuple_get item
  789. MS_LOG(DEBUG) << "output is a tuple getitem node";
  790. auto output_desc = op_output_desces[real_idx];
  791. if (output_desc["shape"].empty()) {
  792. continue;
  793. }
  794. auto ret = GetIOSizeImpl(output_desc);
  795. output_size_list->push_back(ret);
  796. } else {
  797. for (const auto &output_desc : op_output_desces) {
  798. if (output_desc["shape"].empty()) {
  799. continue;
  800. }
  801. auto ret = GetIOSizeImpl(output_desc);
  802. output_size_list->push_back(ret);
  803. }
  804. }
  805. }
  806. }
  807. }
  808. return true;
  809. }
  810. } // namespace kernel
  811. } // namespace mindspore