You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

tbe_adapter.cc 16 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "kernel/tbe/tbe_adapter.h"
  17. #include <map>
  18. #include <unordered_set>
  19. #include <string>
  20. #include <memory>
  21. #include <vector>
  22. #include <algorithm>
  23. #include "session/anf_runtime_algorithm.h"
  24. #include "kernel/oplib/opinfo.h"
  25. namespace mindspore {
  26. namespace kernel {
  27. namespace tbe {
  28. static std::map<string, string> tbe_func_adapter_map = {
  29. {"softmax", "softmax_v2"},
  30. {"log_softmax", "log_softmax_v2"},
  31. {"apply_momentum", "apply_momentum_d"},
  32. {"apply_ftrl", "apply_ftrl_d"},
  33. {"re_lu6", "relu6"},
  34. {"re_lu6_grad", "relu6_grad"},
  35. {"re_lu", "relu"},
  36. {"re_luv2", "relu_v2"},
  37. {"p_re_lu", "prelu"},
  38. {"p_re_lu_grad", "prelu_grad"},
  39. {"tensor_add", "add"},
  40. {"reduce_mean", "reduce_mean_d"},
  41. {"reduce_max", "reduce_max_d"},
  42. {"reduce_min", "reduce_min_d"},
  43. {"avg_pool_grad", "avg_pool_grad_d"},
  44. {"conv2d_backprop_filter", "conv2d_backprop_filter_d"},
  45. {"conv2d_backprop_input", "conv2d_backprop_input_d"},
  46. {"depthwise_conv2d_native", "depthwise_conv2d"},
  47. {"depthwise_conv2d_native_backprop_filter", "depthwise_conv2d_backprop_filter_d"},
  48. {"depthwise_conv2d_native_backprop_input", "depthwise_conv2d_backprop_input_d"},
  49. {"scatter_nd", "scatter_nd_d"},
  50. {"tile", "tile_d"},
  51. {"gather_v2", "gather_v2_d"},
  52. {"sparse_gather_v2", "gather_v2_d"},
  53. {"batch_mat_mul", "batch_matmul"},
  54. {"b_n_training_reduce", "bn_training_reduce"},
  55. {"b_n_training_update", "bn_training_update"},
  56. {"b_n_training_update_v2", "bn_training_update_v2"},
  57. {"b_n_training_update_v3", "bn_training_update_v3"},
  58. {"b_n_training_reduce_grad", "bn_training_reduce_grad"},
  59. {"b_n_training_update_grad", "bn_training_update_grad"},
  60. {"b_n_infer", "bn_infer"},
  61. {"b_n_infer_grad", "bn_infer_grad"},
  62. {"n_pu_clear_float_status", "n_p_u_clear_float_status"},
  63. {"n_pu_get_float_status", "n_p_u_get_float_status"},
  64. {"n_pu_alloc_float_status", "n_p_u_alloc_float_status"},
  65. {"dropout_do_mask", "drop_out_do_mask"},
  66. {"strided_slice", "strided_slice_d"},
  67. {"strided_slice_grad", "strided_slice_grad_d"},
  68. {"sparse_apply_ftrl", "sparse_apply_ftrl_d"},
  69. {"apply_ada_max", "apply_ada_max_d"},
  70. {"apply_adadelta", "apply_adadelta_d"},
  71. {"apply_adagrad", "apply_adagrad_d"},
  72. {"apply_adagrad_v2", "apply_adagradv2_d"},
  73. {"sparse_apply_adagrad", "sparse_apply_adagrad_d"},
  74. {"apply_proximal_adagrad", "apply_proximal_adagrad_d"},
  75. {"sparse_apply_proximal_adagrad", "sparse_apply_proximal_adagrad_d"},
  76. {"transpose", "transpose_d"},
  77. {"fill", "fill_d"},
  78. {"unsorted_segment_sum", "unsorted_segment_sum_d"},
  79. {"concat", "concat_d"},
  80. {"slice", "slice_d"},
  81. {"reduce_sum", "reduce_sum_d"},
  82. {"inplace_add", "inplace_add_d"},
  83. {"inplace_sub", "inplace_sub_d"},
  84. {"one_hot", "one_hot_d"},
  85. {"sum", "reduce_sum_d"},
  86. {"lamb_next_mv_with_decay", "lamb_next_m_v_with_decay"},
  87. {"lamb_next_mv", "lamb_next_m_v"},
  88. {"split", "split_d"},
  89. {"split_v", "split_v_d"},
  90. {"resize_nearest_neighbor", "resize_nearest_neighbor_v2_d"},
  91. {"resize_nearest_neighbor_grad", "resize_nearest_neighbor_v2_grad_d"},
  92. {"pad", "pad_d"},
  93. {"argmax", "arg_max_d"},
  94. {"argmin", "arg_min_d"},
  95. {"space_to_batch", "space_to_batch_d"},
  96. {"batch_to_space", "batch_to_space_d"},
  97. {"space_to_batch_nd", "space_to_batch_nd_d"},
  98. {"batch_to_space_nd", "batch_to_space_nd_d"},
  99. {"resize_bilinear", "resize_bilinear_v2_d"},
  100. {"resize_bilinear_grad", "resize_bilinear_v2_grad"},
  101. {"adam", "apply_adam_d"},
  102. {"r_oi_align", "roi_align"},
  103. {"r_oi_align_grad", "roi_align_grad"},
  104. {"i_ou", "iou"},
  105. {"s_gd", "sgd"},
  106. {"l_rn", "lrn"},
  107. {"l_rn_grad", "lrn_grad"},
  108. {"l_ars_update", "lars_v2_update"},
  109. {"n_ms_with_mask", "nms_with_mask"},
  110. {"square_sum_all", "square_sum_all"},
  111. {"cum_sum", "cumsum_d"},
  112. {"range", "range_d"},
  113. {"lin_space", "lin_space_d"},
  114. {"inv_grad", "inv_grad"},
  115. {"apply_rms_prop", "apply_rms_prop_d"},
  116. {"cum_prod", "cumprod_d"},
  117. {"reduce_all", "reduce_all_d"},
  118. {"sparse_apply_adagrad", "sparse_apply_adagrad_d"},
  119. {"unsorted_segment_min", "unsorted_segment_min_d"},
  120. {"reduce_prod", "reduce_prod_d"},
  121. {"a_cos", "acos"},
  122. {"a_cos_grad", "acos_grad"},
  123. {"histogram_fixed_width", "histogram_fixed_width_d"},
  124. {"broadcast_to", "broadcast_to_d"},
  125. {"inplace_update", "inplace_update_d"},
  126. {"matrix_diag", "matrix_diag_d"},
  127. {"matrix_diag_part", "matrix_diag_part_d"},
  128. {"matrix_set_diag", "matrix_set_diag_d"}};
  129. void TbeAdapter::NormalizeFuncName(std::string *func_name) {
  130. if (func_name == nullptr) {
  131. MS_LOG(EXCEPTION) << "func_name is null";
  132. }
  133. std::string name_tmp;
  134. bool sub_head = false;
  135. for (string::iterator iter = func_name->begin(); iter != func_name->end(); ++iter) {
  136. if (islower(*iter)) {
  137. sub_head = false;
  138. }
  139. if (isdigit(*iter)) {
  140. sub_head = true;
  141. }
  142. if (isupper(*iter) && iter != func_name->begin()) {
  143. if (!sub_head) {
  144. (void)name_tmp.insert(name_tmp.end(), '_');
  145. sub_head = true;
  146. } else {
  147. string::iterator iter_next = iter + 1;
  148. if (iter_next != func_name->end()) {
  149. if (islower(*iter_next)) {
  150. (void)name_tmp.insert(name_tmp.end(), '_');
  151. }
  152. }
  153. }
  154. }
  155. (void)name_tmp.insert(name_tmp.end(), *iter);
  156. }
  157. (void)transform(name_tmp.begin(), name_tmp.end(), name_tmp.begin(), ::tolower);
  158. *func_name = name_tmp;
  159. auto iter = tbe_func_adapter_map.find(*func_name);
  160. if (iter != tbe_func_adapter_map.end()) {
  161. MS_LOG(INFO) << "map actual op from me " << *func_name << " to tbe op" << iter->second;
  162. *func_name = iter->second;
  163. }
  164. }
// For TransData nodes, records the layouts on either side of the conversion as
// node attributes ("src_format" / "dst_format") so the TBE kernel build can see
// them. Formats are read from the node's first input and first output, and
// DefaultFormat is canonicalized to NCHW before being stored.
// Nodes with any other op name are left untouched.
void TbeAdapter::SetTbeAttrsForTransDataOp(const mindspore::AnfNodePtr &anf_node) {
  MS_EXCEPTION_IF_NULL(anf_node);
  if (AnfAlgo::GetCNodeName(anf_node) == kTransDataOpName) {
    std::string input_format = AnfAlgo::GetInputFormat(anf_node, 0);
    std::string output_format = AnfAlgo::GetOutputFormat(anf_node, 0);
    // DefaultFormat is an internal placeholder; TBE expects a concrete layout.
    if (input_format == kOpFormat_DEFAULT) {
      input_format = kOpFormat_NCHW;
    }
    if (output_format == kOpFormat_DEFAULT) {
      output_format = kOpFormat_NCHW;
    }
    AnfAlgo::SetNodeAttr("src_format", MakeValue(input_format), anf_node);
    AnfAlgo::SetNodeAttr("dst_format", MakeValue(output_format), anf_node);
  }
}
// Ops whose TBE kernel consumes its inputs in a different order than the
// MindSpore primitive supplies them; the *OrderPass helpers in this file
// reorder inputs for exactly this set and pass everything else through.
std::unordered_set<std::string> input_order_adjusted_ops = {
  "Conv2DBackpropInput", "Conv2DBackpropFilter", "LogSoftmaxGrad", "LayerNormGrad", "LayerNormXBackprop",
  "LayerNormBetaGammaBackprop", "MinimumGrad", "MaximumGrad", "ApplyCenteredRMSProp"};
  183. void TbeAdapter::InputOrderPass(const std::string &op_name, std::vector<std::vector<nlohmann::json>> const &inputs_list,
  184. nlohmann::json *inputs_json) {
  185. MS_EXCEPTION_IF_NULL(inputs_json);
  186. if (input_order_adjusted_ops.find(op_name) == input_order_adjusted_ops.end()) {
  187. (void)std::copy(inputs_list.begin(), inputs_list.end(), std::back_inserter((*inputs_json)));
  188. } else {
  189. if (op_name == "MinimumGrad" || op_name == "MaximumGrad") {
  190. inputs_json->push_back(inputs_list[2]);
  191. inputs_json->push_back(inputs_list[0]);
  192. inputs_json->push_back(inputs_list[1]);
  193. for (size_t i = 3; i < inputs_list.size(); ++i) {
  194. inputs_json->push_back(inputs_list[i]);
  195. }
  196. } else if (op_name == "ApplyCenteredRMSProp") {
  197. // Parameter order of ApplyCenteredRMSProp's TBE implementation is different from python API, so map
  198. // TBE parameter to correspond python API parameter by latter's index using hardcode
  199. inputs_json->push_back(inputs_list[0]);
  200. inputs_json->push_back(inputs_list[1]);
  201. inputs_json->push_back(inputs_list[2]);
  202. inputs_json->push_back(inputs_list[3]);
  203. inputs_json->push_back(inputs_list[5]);
  204. inputs_json->push_back(inputs_list[6]);
  205. inputs_json->push_back(inputs_list[7]);
  206. inputs_json->push_back(inputs_list[8]);
  207. inputs_json->push_back(inputs_list[4]);
  208. } else {
  209. inputs_json->push_back(inputs_list[1]);
  210. inputs_json->push_back(inputs_list[0]);
  211. for (size_t i = 2; i < inputs_list.size(); ++i) {
  212. inputs_json->push_back(inputs_list[i]);
  213. }
  214. }
  215. }
  216. }
  217. void TbeAdapter::FusionInputOrderPass(const std::string &op_name, const std::vector<nlohmann::json> &inputs_list,
  218. std::vector<nlohmann::json> *inputs_json) {
  219. MS_EXCEPTION_IF_NULL(inputs_json);
  220. if (input_order_adjusted_ops.find(op_name) == input_order_adjusted_ops.end()) {
  221. (void)std::copy(inputs_list.begin(), inputs_list.end(), std::back_inserter((*inputs_json)));
  222. } else {
  223. if (op_name == "MinimumGrad" || op_name == "MaximumGrad") {
  224. inputs_json->emplace_back(inputs_list[2]);
  225. inputs_json->emplace_back(inputs_list[0]);
  226. inputs_json->emplace_back(inputs_list[1]);
  227. for (size_t i = 3; i < inputs_list.size(); ++i) {
  228. inputs_json->emplace_back(inputs_list[i]);
  229. }
  230. } else {
  231. inputs_json->emplace_back(inputs_list[1]);
  232. inputs_json->emplace_back(inputs_list[0]);
  233. for (size_t i = 2; i < inputs_list.size(); ++i) {
  234. inputs_json->emplace_back(inputs_list[i]);
  235. }
  236. }
  237. }
  238. }
  239. void TbeAdapter::FusionDataOrderPass(const std::string &op_name, const std::vector<AnfNodePtr> &data_layer,
  240. std::vector<AnfNodePtr> *reorder_data_layer) {
  241. MS_EXCEPTION_IF_NULL(reorder_data_layer);
  242. if (input_order_adjusted_ops.find(op_name) == input_order_adjusted_ops.end()) {
  243. (void)std::copy(data_layer.begin(), data_layer.end(), std::back_inserter((*reorder_data_layer)));
  244. } else {
  245. if (op_name == "MinimumGrad" || op_name == "MaximumGrad") {
  246. reorder_data_layer->emplace_back(data_layer[2]);
  247. reorder_data_layer->emplace_back(data_layer[0]);
  248. reorder_data_layer->emplace_back(data_layer[1]);
  249. for (size_t i = 3; i < data_layer.size(); ++i) {
  250. reorder_data_layer->emplace_back(data_layer[i]);
  251. }
  252. } else {
  253. reorder_data_layer->emplace_back(data_layer[1]);
  254. reorder_data_layer->emplace_back(data_layer[0]);
  255. for (size_t i = 2; i < data_layer.size(); ++i) {
  256. reorder_data_layer->emplace_back(data_layer[i]);
  257. }
  258. }
  259. }
  260. }
// Per-op attribute-JSON builders; RunAttrPass looks the CNode name up here and
// dispatches to the matching pass. Ops not listed fall back to generic handling.
std::map<std::string, FAttrsPass> TbeAdapter::build_json_attr_pass_map_ = {
  {"MaximumGrad", TbeAdapter::MaximumGradAttrJsonPass},
  {"MinimumGrad", TbeAdapter::MinimumGradAttrJsonPass},
  {"Cast", TbeAdapter::CastAttrJsonPass}};
  265. bool TbeAdapter::RunAttrPass(const mindspore::AnfNodePtr &anf_node,
  266. const std::vector<std::shared_ptr<mindspore::kernel::OpAttr>> &op_info_attrs,
  267. nlohmann::json *attrs_json) {
  268. MS_EXCEPTION_IF_NULL(attrs_json);
  269. auto cnode_name = AnfAlgo::GetCNodeName(anf_node);
  270. auto FPass = build_json_attr_pass_map_.find(cnode_name);
  271. if (FPass != build_json_attr_pass_map_.end()) {
  272. FPass->second(anf_node, op_info_attrs, attrs_json);
  273. return true;
  274. }
  275. return false;
  276. }
  277. void TbeAdapter::MaximumGradAttrJsonPass(const mindspore::AnfNodePtr &anf_node,
  278. const std::vector<std::shared_ptr<mindspore::kernel::OpAttr>> &op_info_attrs,
  279. nlohmann::json *attrs_json) {
  280. MS_EXCEPTION_IF_NULL(anf_node);
  281. MS_EXCEPTION_IF_NULL(attrs_json);
  282. auto attr_num = op_info_attrs.size();
  283. auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
  284. MS_EXCEPTION_IF_NULL(primitive);
  285. for (size_t i = 0; i < attr_num; i++) {
  286. nlohmann::json attr_obj;
  287. MS_EXCEPTION_IF_NULL(op_info_attrs[i]);
  288. std::string attr_name = op_info_attrs[i]->name();
  289. auto value = primitive->GetAttr(attr_name);
  290. if (value != nullptr) {
  291. bool attr_value = GetValue<bool>(value);
  292. attr_obj["value"] = attr_value;
  293. attr_obj["valid"] = true;
  294. } else {
  295. attr_obj["valid"] = false;
  296. }
  297. attr_obj["name"] = attr_name;
  298. attrs_json->push_back(attr_obj);
  299. }
  300. MS_LOG(INFO) << "MaximumGradAttrJsonPass done.";
  301. }
  302. void TbeAdapter::MinimumGradAttrJsonPass(const mindspore::AnfNodePtr &anf_node,
  303. const std::vector<std::shared_ptr<mindspore::kernel::OpAttr>> &op_info_attrs,
  304. nlohmann::json *attrs_json) {
  305. MS_EXCEPTION_IF_NULL(anf_node);
  306. MS_EXCEPTION_IF_NULL(attrs_json);
  307. auto attr_num = op_info_attrs.size();
  308. auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
  309. MS_EXCEPTION_IF_NULL(primitive);
  310. for (size_t i = 0; i < attr_num; i++) {
  311. nlohmann::json attr_obj;
  312. MS_EXCEPTION_IF_NULL(op_info_attrs[i]);
  313. std::string attr_name = op_info_attrs[i]->name();
  314. auto value = primitive->GetAttr(attr_name);
  315. if (value != nullptr) {
  316. bool attr_value = GetValue<bool>(value);
  317. attr_obj["value"] = attr_value;
  318. attr_obj["valid"] = true;
  319. } else {
  320. attr_obj["valid"] = false;
  321. }
  322. attr_obj["name"] = attr_name;
  323. attrs_json->push_back(attr_obj);
  324. }
  325. MS_LOG(INFO) << "MinimumGradAttrJsonPass done.";
  326. }
  327. static int TypeStrToDstType(const std::string &type_str) {
  328. int ret = -1;
  329. if (type_str == "Float" || type_str == "Float32") {
  330. ret = 0;
  331. } else if (type_str == "Float16") {
  332. ret = 1;
  333. } else if (type_str == "Int8") {
  334. ret = 2;
  335. } else if (type_str == "Int32") {
  336. ret = 3;
  337. } else if (type_str == "UInt8") {
  338. ret = 4;
  339. } else if (type_str == "UInt64") {
  340. ret = 10;
  341. } else if (type_str == "Bool") {
  342. ret = 12;
  343. } else {
  344. MS_LOG(INFO) << "Error type str is invailed: " << type_str;
  345. }
  346. return ret;
  347. }
// Builds the single "dst_type" attribute JSON for a Cast node: the destination
// dtype is derived from the node's first output device dtype and converted to
// the numeric TBE code via TypeStrToDstType. Expects exactly one entry in
// op_info_attrs (the dst_type attr); logs and returns without output otherwise.
void TbeAdapter::CastAttrJsonPass(const mindspore::AnfNodePtr &anf_node,
                                  const std::vector<std::shared_ptr<mindspore::kernel::OpAttr>> &op_info_attrs,
                                  nlohmann::json *attrs_json) {
  MS_EXCEPTION_IF_NULL(anf_node);
  MS_EXCEPTION_IF_NULL(attrs_json);
  if (op_info_attrs.size() != 1) {
    MS_LOG(INFO) << "cast node should has dst_type attr";
    return;
  }
  auto attr_name = op_info_attrs[0]->name();
  // Wrap the output device TypeId in a TensorType to get its printable element type.
  auto type_ptr = std::make_shared<TensorType>(TypeIdToType(AnfAlgo::GetOutputDeviceDataType(anf_node, 0)));
  MS_EXCEPTION_IF_NULL(type_ptr);
  auto type_element = type_ptr->element();
  MS_EXCEPTION_IF_NULL(type_element);
  auto dtype = type_element->ToString();
  // e.g. "Float16" -> 1; yields -1 if the dtype string is not supported.
  auto dst_type_value = TypeStrToDstType(dtype);
  nlohmann::json attr_obj;
  attr_obj["value"] = dst_type_value;
  attr_obj["valid"] = true;
  attr_obj["name"] = attr_name;
  attrs_json->push_back(attr_obj);
  MS_LOG(INFO) << "CastAttrJsonPass done.";
}
  371. void TbeAdapter::GenTopKV2IndicesTensorInfo(const std::shared_ptr<mindspore::AnfNode> &anf_node,
  372. size_t real_input_index, std::vector<nlohmann::json> *input_list,
  373. mindspore::kernel::kCreaterType creater_type) {
  374. MS_EXCEPTION_IF_NULL(anf_node);
  375. MS_EXCEPTION_IF_NULL(input_list);
  376. auto input_x_shape = AnfAlgo::GetOutputInferShape(anf_node, 0);
  377. size_t last_dim = input_x_shape[input_x_shape.size() - 1];
  378. std::vector<size_t> tensor_shape = {last_dim};
  379. std::vector<size_t> tensor_origin_shape = {last_dim};
  380. std::string tensor_format = AnfAlgo::GetInputFormat(anf_node, static_cast<const size_t &>(real_input_index));
  381. if (tensor_format == kOpFormat_DEFAULT) {
  382. tensor_format = kOpFormat_NCHW;
  383. }
  384. std::string tensor_origin_format = kOpFormat_NCHW;
  385. std::string tensor_dtype = "float16";
  386. nlohmann::json input_desc_json;
  387. input_desc_json["dtype"] = tensor_dtype;
  388. input_desc_json["name"] = AnfAlgo::GetCNodeName(anf_node);
  389. input_desc_json["ori_shape"] = tensor_origin_shape;
  390. input_desc_json["ori_format"] = tensor_origin_format;
  391. input_desc_json["shape"] = tensor_shape;
  392. if (creater_type == OP_SELECT_FORMAT) {
  393. input_desc_json["format"] = tensor_origin_format;
  394. } else {
  395. input_desc_json["format"] = tensor_format;
  396. }
  397. input_desc_json["valid"] = true;
  398. input_list->emplace_back(input_desc_json);
  399. }
  400. } // namespace tbe
  401. } // namespace kernel
  402. } // namespace mindspore