You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

tensorflow_fusion_custom_parser_adapter.cc 4.2 kB

5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "parser/tensorflow/tensorflow_fusion_custom_parser_adapter.h"
  17. #include "common/util.h"
  18. #include "framework/common/debug/ge_log.h"
  19. #include "parser/common/op_parser_factory.h"
  20. #include "register/op_registry.h"
  21. using domi::FusionParseParamFunc;
  22. using domi::FusionParseParamByOpFunc;
  23. namespace ge {
  24. Status TensorFlowFusionCustomParserAdapter::ParseParams(const vector<const NodeDef *> &v_input_const,
  25. ge::NodePtr &node) const {
  26. GE_CHECK_NOTNULL(node);
  27. auto op_dest = node->GetOpDesc();
  28. GE_CHECK_NOTNULL(op_dest);
  29. std::vector<const google::protobuf::Message *> inside_nodes;
  30. for (auto inside_node : v_input_const) {
  31. GE_CHECK_NOTNULL(inside_node);
  32. const google::protobuf::Message *node_src = reinterpret_cast<const google::protobuf::Message *>(inside_node);
  33. inside_nodes.push_back(node_src);
  34. }
  35. std::string ori_type = op_dest->GetType();
  36. (void)ge::AttrUtils::GetStr(node->GetOpDesc(), ge::ATTR_NAME_FUSIONOP_ORIGINAL_TYPE, ori_type);
  37. FusionParseParamFunc
  38. custom_op_parser = domi::OpRegistry::Instance()->GetFusionParseParamFunc(op_dest->GetType(), ori_type);
  39. if (custom_op_parser == nullptr) {
  40. REPORT_CALL_ERROR("E19999", "No FusionParseParamFunc of node:%s(%s) exist in OpRegistry",
  41. node->GetName().c_str(), node->GetType().c_str());
  42. GELOGE(FAILED, "No FusionParseParamFunc of node:%s(%s) exist in OpRegistry",
  43. node->GetName().c_str(), node->GetType().c_str());
  44. return FAILED;
  45. }
  46. GELOGI("Get fusion parser succ, node: %s.", node->GetName().c_str());
  47. ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_dest);
  48. GE_CHK_BOOL_RET_STATUS(custom_op_parser(inside_nodes, op) == SUCCESS, FAILED,
  49. "Custom parse params failed for node:%s(%s)",
  50. node->GetName().c_str(), node->GetType().c_str());
  51. op.BreakConnect();
  52. GELOGI("Run fusion parser succ, node: %s.", node->GetName().c_str());
  53. return SUCCESS;
  54. }
  55. Status TensorFlowFusionCustomParserAdapter::ParseParams(const std::vector<ge::Operator> &v_input_const,
  56. ge::NodePtr &node) const {
  57. GE_CHECK_NOTNULL(node);
  58. auto op_dest = node->GetOpDesc();
  59. GE_CHECK_NOTNULL(op_dest);
  60. GELOGI("Custom fusion begin to parse params, node: %s.", node->GetName().c_str());
  61. std::string ori_type = op_dest->GetType();
  62. (void)ge::AttrUtils::GetStr(node->GetOpDesc(), ge::ATTR_NAME_FUSIONOP_ORIGINAL_TYPE, ori_type);
  63. FusionParseParamByOpFunc
  64. custom_op_parser = domi::OpRegistry::Instance()->GetFusionParseParamByOpFunc(op_dest->GetType(), ori_type);
  65. if (custom_op_parser == nullptr) {
  66. REPORT_CALL_ERROR("E19999", "No FusionParseParamByOpFunc of node:%s(%s) exist in OpRegistry",
  67. node->GetName().c_str(), node->GetType().c_str());
  68. GELOGE(FAILED, "No FusionParseParamByOpFunc of node:%s(%s) exist in OpRegistry",
  69. node->GetName().c_str(), node->GetType().c_str());
  70. return FAILED;
  71. }
  72. ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_dest);
  73. GE_CHK_BOOL_RET_STATUS(custom_op_parser(v_input_const, op) == SUCCESS, FAILED,
  74. "Custom parser params failedfor node:%s(%s)",
  75. node->GetName().c_str(), node->GetType().c_str());
  76. for (const auto &op_src : v_input_const) {
  77. op_src.BreakConnect();
  78. }
  79. op.BreakConnect();
  80. GELOGI("Run fusion parser succ, node: %s.", node->GetName().c_str());
  81. return SUCCESS;
  82. }
  83. } // namespace ge