You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensorflow_auto_mapping_parser_adapter.cc 8.3 kB

5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
3 years ago
4 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. /**
  2. * Copyright (c) Huawei Technologies Co., Ltd. 2022. All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "tensorflow_auto_mapping_parser_adapter.h"
  17. #include "framework/omg/parser/parser_types.h"
  18. #include "common/util.h"
  19. #include "framework/common/debug/ge_log.h"
  20. #include "graph/def_types.h"
  21. #include "parser/common/op_parser_factory.h"
  22. #include "register/op_registry.h"
  23. #include "register/register.h"
  24. #include "register/register_utils.h"
  25. #include "parser/common/parser_utils.h"
  26. using domi::TENSORFLOW;
  27. using namespace ge::parser;
  28. using ge::parser::PLACEHOLDERWITHDEFAULT;
  29. namespace ge {
  30. namespace {
  31. const char *const kTfAttrT = "T";
  32. const char *const kShapeAttrOutType = "out_type";
  33. const char *const kShapeAttrDtype = "dtype";
  34. } // namespace
  35. Status TensorFlowAutoMappingParserAdapter::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) {
  36. if (op_src == nullptr) {
  37. REPORT_INNER_ERROR("E19999", "Param op_src is nullptr, check invalid");
  38. GELOGE(PARAM_INVALID, "Op src is null");
  39. return PARAM_INVALID;
  40. }
  41. const domi::tensorflow::NodeDef *node = PtrToPtr<const Message, const domi::tensorflow::NodeDef>(op_src);
  42. GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str());
  43. if (op_dest == nullptr) {
  44. REPORT_INNER_ERROR("E19999", "Param op_dest is nullptr, check invalid");
  45. GELOGE(FAILED, "Op dest is null");
  46. return PARAM_INVALID;
  47. }
  48. ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_dest);
  49. Status ret = domi::OperatorAutoMapping(op_src, op);
  50. if (ret != SUCCESS) {
  51. REPORT_CALL_ERROR("E19999", "call auto mapping failed for node:%s", ParserUtils::GetOperatorName(op).c_str());
  52. GELOGE(FAILED, "Tensorflow auto mapping parser params failed");
  53. return FAILED;
  54. }
  55. op.BreakConnect();
  56. if (op_dest->GetType() == EMPTY) {
  57. domi::tensorflow::AttrValue attr;
  58. if (TensorFlowUtil::FindAttrValue(node, kShapeAttrDtype, attr)) {
  59. ge::DataType data_type = domi::TensorAssign::ConvertTensorflowDataType(static_cast<uint32_t>(attr.type()));
  60. AttrUtils::SetInt(op_dest, kShapeAttrDtype, data_type);
  61. GELOGD("Get dtype:%d success.", data_type);
  62. } else {
  63. GELOGW("Get dtype failed!");
  64. }
  65. }
  66. // add dynamic input/output
  67. if (op_dest->GetType() == IDENTITYN) {
  68. uint32_t dynamic_tensor_num = 0;
  69. domi::tensorflow::AttrValue attr_num;
  70. if (!(TensorFlowUtil::FindAttrValue(node, kTfAttrT, attr_num))) {
  71. GELOGW("In NodeDef %s dynamic attr [%s] is not exist.", op_dest->GetName().c_str(), kTfAttrT);
  72. }
  73. dynamic_tensor_num = attr_num.list().type_size();
  74. GE_CHK_STATUS_RET(op_dest->AddDynamicInputDesc("x", dynamic_tensor_num), "AddDynamicInputDesc failed");
  75. GE_CHK_STATUS_RET(op_dest->AddDynamicOutputDesc("y", dynamic_tensor_num), "AddDynamicInputDesc failed");
  76. GELOGI("add dynamic intput and output for op [%s], type[%s], number:%u", op_dest->GetName().c_str(),
  77. op_dest->GetType().c_str(), dynamic_tensor_num);
  78. }
  79. if (op_dest->GetType() == SIZE) {
  80. ge::DataType out_type = DT_INT32;
  81. if (AttrUtils::GetDataType(op_dest, kShapeAttrOutType, out_type)) {
  82. if (!AttrUtils::SetInt(op_dest, kShapeAttrDtype, static_cast<int64_t>(out_type))) {
  83. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", kShapeAttrDtype,
  84. op_dest->GetName().c_str(), op_dest->GetType().c_str());
  85. GELOGE(FAILED, "Set attr dtype for op:%s failed.", op_dest->GetName().c_str());
  86. return FAILED;
  87. }
  88. }
  89. }
  90. // add nodedef for shape insert by adapter when online_infer_dynamic
  91. if (op_dest->GetType() == SHAPE) {
  92. ge::DataType out_type = DT_INT32;
  93. if (AttrUtils::GetDataType(op_dest, kShapeAttrOutType, out_type)) {
  94. if (!AttrUtils::SetInt(op_dest, kShapeAttrDtype, static_cast<int64_t>(out_type))) {
  95. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", kShapeAttrDtype,
  96. op_dest->GetName().c_str(), op_dest->GetType().c_str());
  97. GELOGE(FAILED, "Set attr dtype for op:%s failed.", op_dest->GetName().c_str());
  98. return FAILED;
  99. }
  100. }
  101. const auto out_desc = op_dest->MutableOutputDesc(0);
  102. GE_CHECK_NOTNULL(out_desc);
  103. out_desc->SetDataType(out_type);
  104. std::shared_ptr<domi::tensorflow::NodeDef> pkg_node = ge::parser::MakeShared<domi::tensorflow::NodeDef>();
  105. GE_CHECK_NOTNULL(pkg_node);
  106. pkg_node->CopyFrom(*node);
  107. // Get the property opdef, if the property does not exist, return failure
  108. pkg_node->mutable_attr()->erase(ge::ATTR_NAME_FRAMEWORK_OP_DEF);
  109. pkg_node->mutable_attr()->erase(ge::ATTR_NAME_OUTPUT_TENSOR_DESC);
  110. pkg_node->mutable_attr()->erase(ge::ATTR_NAME_INPUT_TENSOR_DESC);
  111. pkg_node->mutable_attr()->erase(ge::VAR_ATTR_NAME);
  112. // Serialize nodedef into string and package as a whole
  113. string serialized_node;
  114. GE_IF_BOOL_EXEC(!pkg_node->SerializeToString(&serialized_node),
  115. REPORT_CALL_ERROR("E19999", "Trans NodeDef:%s(%s) to string failed",
  116. pkg_node->name().c_str(), pkg_node->op().c_str());
  117. GELOGE(PARAM_INVALID, "In FrameworkOp trans NodeDef to string failed.");
  118. return PARAM_INVALID);
  119. (void)AttrUtils::SetZeroCopyBytes(
  120. op_dest, ge::ATTR_NAME_FRAMEWORK_NODE_DEF,
  121. Buffer::CopyFrom(PtrToPtr<const char_t, const uint8_t>(serialized_node.data()), serialized_node.length()));
  122. GELOGI("node_def of %s is %s.", op_dest->GetName().c_str(), serialized_node.c_str());
  123. }
  124. return SUCCESS;
  125. }
// Register TensorFlowAutoMappingParserAdapter as the TENSORFLOW-framework op parser for each op type
// below. Every type listed here goes through ParseParams' generic auto mapping; only Empty, IdentityN,
// Size and Shape receive additional type-specific handling there.
// NOTE(review): these are static-registration macros whose expansion is not visible here; keep one
// registration per op type and avoid wrapping/reordering without checking the macro definition.
REGISTER_OP_PARSER_CREATOR(TENSORFLOW, PLACEHOLDERWITHDEFAULT, TensorFlowAutoMappingParserAdapter);
REGISTER_OP_PARSER_CREATOR(TENSORFLOW, EXPANDDIMS, TensorFlowAutoMappingParserAdapter);
REGISTER_OP_PARSER_CREATOR(TENSORFLOW, SIZE, TensorFlowAutoMappingParserAdapter);
REGISTER_OP_PARSER_CREATOR(TENSORFLOW, SHAPE, TensorFlowAutoMappingParserAdapter);
REGISTER_OP_PARSER_CREATOR(TENSORFLOW, GUARANTEECONST, TensorFlowAutoMappingParserAdapter);
REGISTER_OP_PARSER_CREATOR(TENSORFLOW, BROADCASTARGS, TensorFlowAutoMappingParserAdapter);
REGISTER_OP_PARSER_CREATOR(TENSORFLOW, PREVENTGRADIENT, TensorFlowAutoMappingParserAdapter);
REGISTER_OP_PARSER_CREATOR(TENSORFLOW, RANK, TensorFlowAutoMappingParserAdapter);
REGISTER_OP_PARSER_CREATOR(TENSORFLOW, BROADCASTGRADIENTARGS, TensorFlowAutoMappingParserAdapter);
REGISTER_OP_PARSER_CREATOR(TENSORFLOW, STOPGRADIENT, TensorFlowAutoMappingParserAdapter);
REGISTER_OP_PARSER_CREATOR(TENSORFLOW, DESTROYTEMPORARYVARIABLE, TensorFlowAutoMappingParserAdapter);
REGISTER_OP_PARSER_CREATOR(TENSORFLOW, SNAPSHOT, TensorFlowAutoMappingParserAdapter);
REGISTER_OP_PARSER_CREATOR(TENSORFLOW, EMPTY, TensorFlowAutoMappingParserAdapter);
REGISTER_OP_PARSER_CREATOR(TENSORFLOW, IDENTITYN, TensorFlowAutoMappingParserAdapter);
REGISTER_OP_PARSER_CREATOR(TENSORFLOW, CONTROLTRIGGER, TensorFlowAutoMappingParserAdapter);
REGISTER_OP_PARSER_CREATOR(TENSORFLOW, SWITCH, TensorFlowAutoMappingParserAdapter);
REGISTER_OP_PARSER_CREATOR(TENSORFLOW, LOOPCOND, TensorFlowAutoMappingParserAdapter);
REGISTER_OP_PARSER_CREATOR(TENSORFLOW, NEXTITERATION, TensorFlowAutoMappingParserAdapter);
REGISTER_OP_PARSER_CREATOR(TENSORFLOW, REFNEXTITERATION, TensorFlowAutoMappingParserAdapter);
REGISTER_OP_PARSER_CREATOR(TENSORFLOW, EXIT, TensorFlowAutoMappingParserAdapter);
REGISTER_OP_PARSER_CREATOR(TENSORFLOW, REFEXIT, TensorFlowAutoMappingParserAdapter);
REGISTER_OP_PARSER_CREATOR(TENSORFLOW, CONSTANT, TensorFlowAutoMappingParserAdapter);
REGISTER_OP_PARSER_CREATOR(TENSORFLOW, PARALLELCONCATSTART, TensorFlowAutoMappingParserAdapter);
REGISTER_OP_PARSER_CREATOR(TENSORFLOW, BITCAST, TensorFlowAutoMappingParserAdapter);
REGISTER_OP_PARSER_CREATOR(TENSORFLOW, IDENTITY, TensorFlowAutoMappingParserAdapter);
  151. }