You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tbe_adapter.cc 16 kB

6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "kernel/tbe/tbe_adapter.h"
  17. #include <map>
  18. #include <unordered_set>
  19. #include <string>
  20. #include <memory>
  21. #include <vector>
  22. #include <algorithm>
  23. #include "session/anf_runtime_algorithm.h"
  24. #include "kernel/oplib/opinfo.h"
  25. namespace mindspore {
  26. namespace kernel {
  27. namespace tbe {
  // Translation table from a normalized (snake_case, lower-case) op name to the name of the
  // TBE operator implementation. It is consulted by TbeAdapter::NormalizeFuncName; names that
  // are not listed are used unchanged.
  //
  // Every key must appear exactly once: std::map silently ignores later duplicates in an
  // initializer list, so a repeated key would hide a conflicting mapping.
  static const std::map<std::string, std::string> tbe_func_adapter_map = {
    {"softmax", "softmax_v2"},
    {"log_softmax", "log_softmax_v2"},
    {"apply_momentum", "apply_momentum_d"},
    {"apply_ftrl", "apply_ftrl_d"},
    {"re_lu6", "relu6"},
    {"re_lu6_grad", "relu6_grad"},
    {"re_lu", "relu"},
    {"re_luv2", "relu_v2"},
    {"p_re_lu", "prelu"},
    {"p_re_lu_grad", "prelu_grad"},
    {"tensor_add", "add"},
    {"reduce_mean", "reduce_mean_d"},
    {"reduce_max", "reduce_max_d"},
    {"reduce_min", "reduce_min_d"},
    {"avg_pool_grad", "avg_pool_grad_d"},
    {"conv2d_backprop_filter", "conv2d_backprop_filter_d"},
    {"conv2d_backprop_input", "conv2d_backprop_input_d"},
    {"depthwise_conv2d_native", "depthwise_conv2d"},
    {"depthwise_conv2d_native_backprop_filter", "depthwise_conv2d_backprop_filter_d"},
    {"depthwise_conv2d_native_backprop_input", "depthwise_conv2d_backprop_input_d"},
    {"scatter_nd", "scatter_nd_d"},
    {"tile", "tile_d"},
    {"gather_v2", "gather_v2_d"},
    {"sparse_gather_v2", "gather_v2_d"},
    {"batch_mat_mul", "batch_matmul"},
    {"b_n_training_reduce", "bn_training_reduce"},
    {"b_n_training_update", "bn_training_update"},
    {"b_n_training_update_v2", "bn_training_update_v2"},
    {"b_n_training_update_v3", "bn_training_update_v3"},
    {"b_n_training_reduce_grad", "bn_training_reduce_grad"},
    {"b_n_training_update_grad", "bn_training_update_grad"},
    {"b_n_infer", "bn_infer"},
    {"b_n_infer_grad", "bn_infer_grad"},
    {"n_pu_clear_float_status", "n_p_u_clear_float_status"},
    {"n_pu_get_float_status", "n_p_u_get_float_status"},
    {"n_pu_alloc_float_status", "n_p_u_alloc_float_status"},
    {"dropout_do_mask", "drop_out_do_mask"},
    {"strided_slice", "strided_slice_d"},
    {"strided_slice_grad", "strided_slice_grad_d"},
    {"sparse_apply_ftrl", "sparse_apply_ftrl_d"},
    {"apply_ada_max", "apply_ada_max_d"},
    {"apply_adadelta", "apply_adadelta_d"},
    {"apply_adagrad", "apply_adagrad_d"},
    {"apply_adagrad_v2", "apply_adagradv2_d"},
    {"sparse_apply_adagrad", "sparse_apply_adagrad_d"},
    {"apply_proximal_adagrad", "apply_proximal_adagrad_d"},
    {"sparse_apply_proximal_adagrad", "sparse_apply_proximal_adagrad_d"},
    {"transpose", "transpose_d"},
    {"fill", "fill_d"},
    {"unsorted_segment_sum", "unsorted_segment_sum_d"},
    {"concat", "concat_d"},
    {"slice", "slice_d"},
    {"reduce_sum", "reduce_sum_d"},
    {"inplace_add", "inplace_add_d"},
    {"inplace_sub", "inplace_sub_d"},
    {"one_hot", "one_hot_d"},
    {"sum", "reduce_sum_d"},
    {"lamb_next_mv_with_decay", "lamb_next_m_v_with_decay"},
    {"lamb_next_mv", "lamb_next_m_v"},
    {"split", "split_d"},
    {"resize_nearest_neighbor", "resize_nearest_neighbor_v2_d"},
    {"resize_nearest_neighbor_grad", "resize_nearest_neighbor_v2_grad_d"},
    {"pad", "pad_d"},
    {"argmax", "arg_max_d"},
    {"argmin", "arg_min_d"},
    {"space_to_batch", "space_to_batch_d"},
    {"batch_to_space", "batch_to_space_d"},
    {"space_to_batch_nd", "space_to_batch_nd_d"},
    {"batch_to_space_nd", "batch_to_space_nd_d"},
    {"resize_bilinear", "resize_bilinear_v2_d"},
    {"resize_bilinear_grad", "resize_bilinear_v2_grad"},
    {"adam", "apply_adam_d"},
    {"r_oi_align", "roi_align"},
    {"r_oi_align_grad", "roi_align_grad"},
    {"i_ou", "iou"},
    {"s_gd", "sgd"},
    {"l_ars_update", "lars_v2_update"},
    {"n_ms_with_mask", "nms_with_mask"},
    {"square_sum_all", "square_sum_all"},
    {"cum_sum", "cumsum_d"},
    {"range", "range_d"},
    {"inv_grad", "inv_grad"},
    {"apply_rms_prop", "apply_rms_prop_d"},
    {"cum_prod", "cumprod_d"},
    {"reduce_all", "reduce_all_d"},
    {"unsorted_segment_min", "unsorted_segment_min_d"},
    {"reduce_prod", "reduce_prod_d"},
    {"a_cos", "acos"},
    {"a_cos_grad", "acos_grad"},
    {"histogram_fixed_width", "histogram_fixed_width_d"},
    {"broadcast_to", "broadcast_to_d"},
    {"inplace_update", "inplace_update_d"}};
  122. void TbeAdapter::NormalizeFuncName(std::string *func_name) {
  123. if (func_name == nullptr) {
  124. MS_LOG(EXCEPTION) << "func_name is null";
  125. }
  126. std::string name_tmp;
  127. bool sub_head = false;
  128. for (string::iterator iter = func_name->begin(); iter != func_name->end(); ++iter) {
  129. if (islower(*iter)) {
  130. sub_head = false;
  131. }
  132. if (isdigit(*iter)) {
  133. sub_head = true;
  134. }
  135. if (isupper(*iter) && iter != func_name->begin()) {
  136. if (!sub_head) {
  137. (void)name_tmp.insert(name_tmp.end(), '_');
  138. sub_head = true;
  139. } else {
  140. string::iterator iter_next = iter + 1;
  141. if (iter_next != func_name->end()) {
  142. if (islower(*iter_next)) {
  143. (void)name_tmp.insert(name_tmp.end(), '_');
  144. }
  145. }
  146. }
  147. }
  148. (void)name_tmp.insert(name_tmp.end(), *iter);
  149. }
  150. (void)transform(name_tmp.begin(), name_tmp.end(), name_tmp.begin(), ::tolower);
  151. *func_name = name_tmp;
  152. auto iter = tbe_func_adapter_map.find(*func_name);
  153. if (iter != tbe_func_adapter_map.end()) {
  154. MS_LOG(INFO) << "map actual op from me " << *func_name << " to tbe op" << iter->second;
  155. *func_name = iter->second;
  156. }
  157. }
  158. void TbeAdapter::SetTbeAttrsForTransDataOp(const mindspore::AnfNodePtr &anf_node) {
  159. MS_EXCEPTION_IF_NULL(anf_node);
  160. if (AnfAlgo::GetCNodeName(anf_node) == kTransDataOpName) {
  161. std::string input_format = AnfAlgo::GetInputFormat(anf_node, 0);
  162. std::string output_format = AnfAlgo::GetOutputFormat(anf_node, 0);
  163. if (input_format == kOpFormat_DEFAULT) {
  164. input_format = kOpFormat_NCHW;
  165. }
  166. if (output_format == kOpFormat_DEFAULT) {
  167. output_format = kOpFormat_NCHW;
  168. }
  169. AnfAlgo::SetNodeAttr("src_format", MakeValue(input_format), anf_node);
  170. AnfAlgo::SetNodeAttr("dst_format", MakeValue(output_format), anf_node);
  171. }
  172. }
  // Op names (as returned for the cnode, not normalized) whose inputs the order-pass functions
  // in this file emit in a different order than they receive them.
  std::unordered_set<std::string> input_order_adjusted_ops{
    "Conv2DBackpropInput",
    "Conv2DBackpropFilter",
    "LogSoftmaxGrad",
    "LayerNormGrad",
    "LayerNormXBackprop",
    "LayerNormBetaGammaBackprop",
    "MinimumGrad",
    "MaximumGrad",
    "ApplyCenteredRMSProp",
  };
  176. void TbeAdapter::InputOrderPass(const std::string &op_name, std::vector<std::vector<nlohmann::json>> const &inputs_list,
  177. nlohmann::json *inputs_json) {
  178. MS_EXCEPTION_IF_NULL(inputs_json);
  179. if (input_order_adjusted_ops.find(op_name) == input_order_adjusted_ops.end()) {
  180. (void)std::copy(inputs_list.begin(), inputs_list.end(), std::back_inserter((*inputs_json)));
  181. } else {
  182. if (op_name == "MinimumGrad" || op_name == "MaximumGrad") {
  183. inputs_json->push_back(inputs_list[2]);
  184. inputs_json->push_back(inputs_list[0]);
  185. inputs_json->push_back(inputs_list[1]);
  186. for (size_t i = 3; i < inputs_list.size(); ++i) {
  187. inputs_json->push_back(inputs_list[i]);
  188. }
  189. } else if (op_name == "ApplyCenteredRMSProp") {
  190. // Parameter order of ApplyCenteredRMSProp's TBE implementation is different from python API, so map
  191. // TBE parameter to correspond python API parameter by latter's index using hardcode
  192. inputs_json->push_back(inputs_list[0]);
  193. inputs_json->push_back(inputs_list[1]);
  194. inputs_json->push_back(inputs_list[2]);
  195. inputs_json->push_back(inputs_list[3]);
  196. inputs_json->push_back(inputs_list[5]);
  197. inputs_json->push_back(inputs_list[6]);
  198. inputs_json->push_back(inputs_list[7]);
  199. inputs_json->push_back(inputs_list[8]);
  200. inputs_json->push_back(inputs_list[4]);
  201. } else {
  202. inputs_json->push_back(inputs_list[1]);
  203. inputs_json->push_back(inputs_list[0]);
  204. for (size_t i = 2; i < inputs_list.size(); ++i) {
  205. inputs_json->push_back(inputs_list[i]);
  206. }
  207. }
  208. }
  209. }
  210. void TbeAdapter::FusionInputOrderPass(const std::string &op_name, const std::vector<nlohmann::json> &inputs_list,
  211. std::vector<nlohmann::json> *inputs_json) {
  212. MS_EXCEPTION_IF_NULL(inputs_json);
  213. if (input_order_adjusted_ops.find(op_name) == input_order_adjusted_ops.end()) {
  214. (void)std::copy(inputs_list.begin(), inputs_list.end(), std::back_inserter((*inputs_json)));
  215. } else {
  216. if (op_name == "MinimumGrad" || op_name == "MaximumGrad") {
  217. inputs_json->emplace_back(inputs_list[2]);
  218. inputs_json->emplace_back(inputs_list[0]);
  219. inputs_json->emplace_back(inputs_list[1]);
  220. for (size_t i = 3; i < inputs_list.size(); ++i) {
  221. inputs_json->emplace_back(inputs_list[i]);
  222. }
  223. } else {
  224. inputs_json->emplace_back(inputs_list[1]);
  225. inputs_json->emplace_back(inputs_list[0]);
  226. for (size_t i = 2; i < inputs_list.size(); ++i) {
  227. inputs_json->emplace_back(inputs_list[i]);
  228. }
  229. }
  230. }
  231. }
  232. void TbeAdapter::FusionDataOrderPass(const std::string &op_name, const std::vector<AnfNodePtr> &data_layer,
  233. std::vector<AnfNodePtr> *reorder_data_layer) {
  234. MS_EXCEPTION_IF_NULL(reorder_data_layer);
  235. if (input_order_adjusted_ops.find(op_name) == input_order_adjusted_ops.end()) {
  236. (void)std::copy(data_layer.begin(), data_layer.end(), std::back_inserter((*reorder_data_layer)));
  237. } else {
  238. if (op_name == "MinimumGrad" || op_name == "MaximumGrad") {
  239. reorder_data_layer->emplace_back(data_layer[2]);
  240. reorder_data_layer->emplace_back(data_layer[0]);
  241. reorder_data_layer->emplace_back(data_layer[1]);
  242. for (size_t i = 3; i < data_layer.size(); ++i) {
  243. reorder_data_layer->emplace_back(data_layer[i]);
  244. }
  245. } else {
  246. reorder_data_layer->emplace_back(data_layer[1]);
  247. reorder_data_layer->emplace_back(data_layer[0]);
  248. for (size_t i = 2; i < data_layer.size(); ++i) {
  249. reorder_data_layer->emplace_back(data_layer[i]);
  250. }
  251. }
  252. }
  253. }
  254. std::map<std::string, FAttrsPass> TbeAdapter::build_json_attr_pass_map_ = {
  255. {"MaximumGrad", TbeAdapter::MaximumGradAttrJsonPass},
  256. {"MinimumGrad", TbeAdapter::MinimumGradAttrJsonPass},
  257. {"Cast", TbeAdapter::CastAttrJsonPass}};
  258. bool TbeAdapter::RunAttrPass(const mindspore::AnfNodePtr &anf_node,
  259. const std::vector<std::shared_ptr<mindspore::kernel::OpAttr>> &op_info_attrs,
  260. nlohmann::json *attrs_json) {
  261. MS_EXCEPTION_IF_NULL(attrs_json);
  262. auto cnode_name = AnfAlgo::GetCNodeName(anf_node);
  263. auto FPass = build_json_attr_pass_map_.find(cnode_name);
  264. if (FPass != build_json_attr_pass_map_.end()) {
  265. FPass->second(anf_node, op_info_attrs, attrs_json);
  266. return true;
  267. }
  268. return false;
  269. }
  270. void TbeAdapter::MaximumGradAttrJsonPass(const mindspore::AnfNodePtr &anf_node,
  271. const std::vector<std::shared_ptr<mindspore::kernel::OpAttr>> &op_info_attrs,
  272. nlohmann::json *attrs_json) {
  273. MS_EXCEPTION_IF_NULL(anf_node);
  274. MS_EXCEPTION_IF_NULL(attrs_json);
  275. auto attr_num = op_info_attrs.size();
  276. auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
  277. MS_EXCEPTION_IF_NULL(primitive);
  278. for (size_t i = 0; i < attr_num; i++) {
  279. nlohmann::json attr_obj;
  280. MS_EXCEPTION_IF_NULL(op_info_attrs[i]);
  281. std::string attr_name = op_info_attrs[i]->name();
  282. auto value = primitive->GetAttr(attr_name);
  283. if (value != nullptr) {
  284. bool attr_value = GetValue<bool>(value);
  285. attr_obj["value"] = attr_value;
  286. attr_obj["valid"] = true;
  287. } else {
  288. attr_obj["valid"] = false;
  289. }
  290. attr_obj["name"] = attr_name;
  291. attrs_json->push_back(attr_obj);
  292. }
  293. MS_LOG(INFO) << "MaximumGradAttrJsonPass done.";
  294. }
  295. void TbeAdapter::MinimumGradAttrJsonPass(const mindspore::AnfNodePtr &anf_node,
  296. const std::vector<std::shared_ptr<mindspore::kernel::OpAttr>> &op_info_attrs,
  297. nlohmann::json *attrs_json) {
  298. MS_EXCEPTION_IF_NULL(anf_node);
  299. MS_EXCEPTION_IF_NULL(attrs_json);
  300. auto attr_num = op_info_attrs.size();
  301. auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
  302. MS_EXCEPTION_IF_NULL(primitive);
  303. for (size_t i = 0; i < attr_num; i++) {
  304. nlohmann::json attr_obj;
  305. MS_EXCEPTION_IF_NULL(op_info_attrs[i]);
  306. std::string attr_name = op_info_attrs[i]->name();
  307. auto value = primitive->GetAttr(attr_name);
  308. if (value != nullptr) {
  309. bool attr_value = GetValue<bool>(value);
  310. attr_obj["value"] = attr_value;
  311. attr_obj["valid"] = true;
  312. } else {
  313. attr_obj["valid"] = false;
  314. }
  315. attr_obj["name"] = attr_name;
  316. attrs_json->push_back(attr_obj);
  317. }
  318. MS_LOG(INFO) << "MinimumGradAttrJsonPass done.";
  319. }
  320. static int TypeStrToDstType(const std::string &type_str) {
  321. int ret = -1;
  322. if (type_str == "Float" || type_str == "Float32") {
  323. ret = 0;
  324. } else if (type_str == "Float16") {
  325. ret = 1;
  326. } else if (type_str == "Int8") {
  327. ret = 2;
  328. } else if (type_str == "Int32") {
  329. ret = 3;
  330. } else if (type_str == "UInt8") {
  331. ret = 4;
  332. } else if (type_str == "UInt64") {
  333. ret = 10;
  334. } else if (type_str == "Bool_") {
  335. ret = 12;
  336. } else {
  337. MS_LOG(INFO) << "Error type str is invailed: " << type_str;
  338. }
  339. return ret;
  340. }
  341. void TbeAdapter::CastAttrJsonPass(const mindspore::AnfNodePtr &anf_node,
  342. const std::vector<std::shared_ptr<mindspore::kernel::OpAttr>> &op_info_attrs,
  343. nlohmann::json *attrs_json) {
  344. MS_EXCEPTION_IF_NULL(anf_node);
  345. MS_EXCEPTION_IF_NULL(attrs_json);
  346. if (op_info_attrs.size() != 1) {
  347. MS_LOG(INFO) << "cast node should has dst_type attr";
  348. return;
  349. }
  350. auto attr_name = op_info_attrs[0]->name();
  351. auto type_ptr = std::make_shared<TensorType>(TypeIdToType(AnfAlgo::GetOutputDeviceDataType(anf_node, 0)));
  352. MS_EXCEPTION_IF_NULL(type_ptr);
  353. auto type_element = type_ptr->element();
  354. MS_EXCEPTION_IF_NULL(type_element);
  355. auto dtype = type_element->ToString();
  356. auto dst_type_value = TypeStrToDstType(dtype);
  357. nlohmann::json attr_obj;
  358. attr_obj["value"] = dst_type_value;
  359. attr_obj["valid"] = true;
  360. attr_obj["name"] = attr_name;
  361. attrs_json->push_back(attr_obj);
  362. MS_LOG(INFO) << "CastAttrJsonPass done.";
  363. }
  364. void TbeAdapter::GenTopKV2IndicesTensorInfo(const std::shared_ptr<mindspore::AnfNode> &anf_node,
  365. size_t real_input_index, std::vector<nlohmann::json> *input_list,
  366. mindspore::kernel::kCreaterType creater_type) {
  367. MS_EXCEPTION_IF_NULL(anf_node);
  368. MS_EXCEPTION_IF_NULL(input_list);
  369. auto input_x_shape = AnfAlgo::GetOutputInferShape(anf_node, 0);
  370. size_t last_dim = input_x_shape[input_x_shape.size() - 1];
  371. std::vector<size_t> tensor_shape = {last_dim};
  372. std::vector<size_t> tensor_origin_shape = {last_dim};
  373. std::string tensor_format = AnfAlgo::GetInputFormat(anf_node, static_cast<const size_t &>(real_input_index));
  374. if (tensor_format == kOpFormat_DEFAULT) {
  375. tensor_format = kOpFormat_NCHW;
  376. }
  377. std::string tensor_origin_format = kOpFormat_NCHW;
  378. std::string tensor_dtype = "float16";
  379. nlohmann::json input_desc_json;
  380. input_desc_json["dtype"] = tensor_dtype;
  381. input_desc_json["name"] = AnfAlgo::GetCNodeName(anf_node);
  382. input_desc_json["ori_shape"] = tensor_origin_shape;
  383. input_desc_json["ori_format"] = tensor_origin_format;
  384. input_desc_json["shape"] = tensor_shape;
  385. if (creater_type == OP_SELECT_FORMAT) {
  386. input_desc_json["format"] = tensor_origin_format;
  387. } else {
  388. input_desc_json["format"] = tensor_format;
  389. }
  390. input_desc_json["valid"] = true;
  391. input_list->emplace_back(input_desc_json);
  392. }
  393. } // namespace tbe
  394. } // namespace kernel
  395. } // namespace mindspore