You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensorflow_shape_n_parser.cc 7.1 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "parser/tensorflow/tensorflow_shape_n_parser.h"
  17. #include "parser/common/op_def/ir_pb_converter.h"
  18. #include "framework/common/debug/ge_log.h"
  19. #include "parser/common/op_parser_factory.h"
  20. #include "parser/common/op_def/shape_n_op.h"
  21. #include "parser/common/util.h"
  22. using domi::TENSORFLOW;
  23. using domi::tensorflow::AttrValue;
  24. using domi::tensorflow::DataType;
  25. using domi::tensorflow::DT_FLOAT;
  26. using domi::tensorflow::DT_INT32;
  27. using namespace ge::parser;
  namespace {
  // NodeDef attribute key used by this file to look up the requested output data type.
  const std::string kShapeAttrDtype{"out_type"};
  }  // namespace
  31. namespace ge {
  32. Status TensorFlowShapeNParser::ParseInType(const domi::tensorflow::NodeDef *node, ShapeNOperator *op) {
  33. // The upper caller guarantees the input params is not empty.
  34. domi::tensorflow::AttrValue attr;
  35. CHECK_FALSE_EXEC(TensorFlowUtil::FindAttrValue(node, ge::ATTR_NAME_T, attr),
  36. op->InType(domi::TensorAssign::ConvertTensorflowDataType(DT_FLOAT));
  37. return SUCCESS);
  38. GE_RETURN_WITH_LOG_IF_ERROR(TensorFlowUtil::CheckAttrHasType(attr, "type"), "check Attr T failed");
  39. domi::tensorflow::DataType tf_type = attr.type();
  40. ge::DataType type = domi::TensorAssign::ConvertTensorflowDataType(tf_type);
  41. CHECK_FALSE_EXEC(type != ge::DataType::DT_UNDEFINED,
  42. REPORT_CALL_ERROR("E19999", "Data type %s of node %s is not supported",
  43. DataType_Name(tf_type).c_str(), node->name().c_str());
  44. GELOGE(FAILED, "Data type %s of node %s is not supported.",
  45. DataType_Name(tf_type).c_str(), node->name().c_str());
  46. return PARAM_INVALID);
  47. op->InType(type);
  48. return SUCCESS;
  49. }
  50. Status TensorFlowShapeNParser::ParseOutType(const domi::tensorflow::NodeDef *node, ShapeNOperator *op) {
  51. // The upper caller guarantees the input params is not empty.
  52. domi::tensorflow::AttrValue attr;
  53. CHECK_FALSE_EXEC(TensorFlowUtil::FindAttrValue(node, kShapeAttrDtype, attr),
  54. op->OutType(domi::TensorAssign::ConvertTensorflowDataType(DT_INT32));
  55. return SUCCESS);
  56. GE_RETURN_WITH_LOG_IF_ERROR(TensorFlowUtil::CheckAttrHasType(attr, "type"), "check Attr T failed");
  57. domi::tensorflow::DataType tf_type = attr.type();
  58. ge::DataType type = domi::TensorAssign::ConvertTensorflowDataType(tf_type);
  59. CHECK_FALSE_EXEC(type != ge::DataType::DT_UNDEFINED,
  60. REPORT_CALL_ERROR("E19999", "Data type %s of node %s is not supported",
  61. DataType_Name(tf_type).c_str(), node->name().c_str());
  62. GELOGE(FAILED, "Data type %s of node %s is not supported.",
  63. DataType_Name(tf_type).c_str(), node->name().c_str());
  64. return PARAM_INVALID);
  65. op->OutType(type);
  66. return SUCCESS;
  67. }
  68. Status TensorFlowShapeNParser::ParseN(const domi::tensorflow::NodeDef *node, ShapeNOperator *op) {
  69. // The upper caller guarantees the input params is not empty.
  70. domi::tensorflow::AttrValue attr;
  71. const int64_t attr_n = 2;
  72. CHECK_FALSE_EXEC(TensorFlowUtil::FindAttrValue(node, SHAPEN_ATTR_N, attr), op->N(attr_n); return SUCCESS);
  73. GE_RETURN_WITH_LOG_IF_ERROR(TensorFlowUtil::CheckAttrHasType(attr, "int"), "check Attr N failed");
  74. op->N(attr.i());
  75. return SUCCESS;
  76. }
  77. Status TensorFlowShapeNParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) {
  78. GE_CHECK_NOTNULL(op_dest);
  79. const NodeDef *node = DOMI_DYNAMIC_CAST<const NodeDef *>(op_src);
  80. GE_CHECK_NOTNULL(node);
  81. ShapeNOperator op;
  82. op.Name(node->name());
  83. GE_RETURN_IF_ERROR(PreParseParams(node, &op));
  84. GE_RETURN_WITH_LOG_IF_ERROR(ParseInType(node, &op), "Parse in type for node %s failed.", node->name().c_str());
  85. GE_RETURN_WITH_LOG_IF_ERROR(ParseN(node, &op), "Parse N for node %s failed.", node->name().c_str());
  86. GE_RETURN_WITH_LOG_IF_ERROR(ParseOutType(node, &op), "Parse out type for node %s failed.", node->name().c_str());
  87. GE_RETURN_IF_ERROR(PostParseParams(node, &op));
  88. // add dynamic input/output
  89. domi::tensorflow::AttrValue attr_num;
  90. CHECK_FALSE_EXEC(TensorFlowUtil::FindAttrValue(node, SHAPEN_ATTR_N, attr_num),
  91. REPORT_CALL_ERROR("E19999", "In NodeDef:%s attr:%s not exist, check invalid",
  92. node->name().c_str(), SHAPEN_ATTR_N.c_str());
  93. GELOGE(FAILED, "Get Attr N failed in Node %s.", node->name().c_str());
  94. return PARAM_INVALID);
  95. int32_t dynamic_tensor_num = attr_num.i();
  96. Status ret;
  97. domi::tensorflow::AttrValue output_attr_value;
  98. if (TensorFlowUtil::FindAttrValue(node, ge::ATTR_NAME_OUTPUT_TENSOR_DESC, output_attr_value)) {
  99. GE_CHK_STATUS_RET(
  100. TensorFlowUtil::TransTensorDescriptor(output_attr_value, &op, TENSORFLOW_NORMAL_OUTPUT_TENSOR_FLAG),
  101. "trans output_attr_value failed, op: %s", node->name().c_str());
  102. ret = ConvertToOpDesc(op, op_dest);
  103. if (ret != SUCCESS) {
  104. return ret;
  105. }
  106. } else {
  107. ret = ConvertToOpDesc(op, op_dest);
  108. if (ret != SUCCESS) {
  109. return ret;
  110. }
  111. graphStatus status = op_dest->AddDynamicOutputDesc("y", dynamic_tensor_num);
  112. if (status != GRAPH_SUCCESS) {
  113. REPORT_CALL_ERROR("E19999", "Add Dynamic OuputDesc name:y to node:%s(%s) failed",
  114. op_dest->GetName().c_str(), op_dest->GetType().c_str());
  115. GELOGE(FAILED, "Add dynamic output:y for node:%s failed.", op_dest->GetName().c_str());
  116. return FAILED;
  117. }
  118. }
  119. graphStatus status = op_dest->AddDynamicInputDesc("x", dynamic_tensor_num);
  120. if (status != GRAPH_SUCCESS) {
  121. REPORT_CALL_ERROR("E19999", "Add Dynamic InputDesc name:x to node:%s(%s) failed",
  122. op_dest->GetName().c_str(), op_dest->GetType().c_str());
  123. GELOGE(FAILED, "Add dynamic input:x for node:%s failed.", op_dest->GetName().c_str());
  124. return FAILED;
  125. }
  126. GELOGI("add dynamic input and output for op [%s], type[%s], name: %s, number:%d", op_dest->GetName().c_str(),
  127. op_dest->GetType().c_str(), SHAPEN_ATTR_N.c_str(), dynamic_tensor_num);
  128. return SUCCESS;
  129. }
  130. // AUTO GEN PLEASE DO NOT MODIFY IT
  131. Status TensorFlowShapeNParser::PreParseParams(const domi::tensorflow::NodeDef *node, ShapeNOperator *op) {
  132. return SUCCESS;
  133. }
  134. Status TensorFlowShapeNParser::PostParseParams(const domi::tensorflow::NodeDef *node, ShapeNOperator *op) {
  135. return SUCCESS;
  136. }
  137. REGISTER_OP_PARSER_CREATOR(TENSORFLOW, SHAPEN, TensorFlowShapeNParser);
  138. } // namespace ge