You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensorflow_squeeze_parser.cc 5.1 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include "parser/tensorflow/tensorflow_squeeze_parser.h"
  #include <limits>
  #include <memory>
  #include <vector>
  #include "framework/common/debug/ge_log.h"
  #include "common/util.h"
  #include "framework/omg/parser/parser_inner_ctx.h"
  #include "graph/utils/type_utils.h"
  #include "parser/common/op_parser_factory.h"
  #include "parser/common/acl_graph_parser_util.h"
  25. using domi::tensorflow::AttrValue;
  26. using std::vector;
  27. using std::shared_ptr;
  28. using domi::TENSORFLOW;
  29. using namespace ge::parser;
  30. namespace ge {
  31. Status TensorFlowSqueezeParser::ParseDesc(const domi::tensorflow::AttrValue &attr_value, ge::GeTensorDesc &ge_desc) {
  32. int32_t tf_datatype = 0;
  33. auto a_list = attr_value.list();
  34. GE_CHK_BOOL_RET_STATUS(TensorFlowUtil::ParseFromAttrValueList(ge_desc, a_list, 0, tf_datatype), domi::PARAM_INVALID,
  35. "parse ge_desc failed.");
  36. uint32_t size_type;
  37. int64_t real_size = 1;
  38. int64_t tmp_dim = 0;
  39. auto data_type = ge_desc.GetDataType();
  40. bool type_ret = ge::TypeUtils::GetDataTypeLength(data_type, size_type);
  41. GE_IF_BOOL_EXEC(!type_ret, GELOGE(domi::PARAM_INVALID, "Can't GetDataTypeLength of data_type: %s",
  42. ge::TypeUtils::DataTypeToSerialString(data_type).c_str());
  43. return domi::PARAM_INVALID);
  44. // calculate size
  45. for (uint32_t j = 0; j < ge_desc.GetShape().GetDimNum(); ++j) {
  46. tmp_dim = ge_desc.GetShape().GetDim(j);
  47. GE_IF_BOOL_EXEC(tmp_dim < 0, real_size = tmp_dim * (-1) * real_size; continue;);
  48. PARSER_INT64_MULCHECK(real_size, tmp_dim);
  49. real_size *= tmp_dim;
  50. }
  51. PARSER_INT64_MULCHECK(real_size, size_type);
  52. ge::TensorUtils::SetSize(ge_desc, real_size * size_type);
  53. ge::TensorUtils::SetRealDimCnt(ge_desc, ge_desc.GetShape().GetDimNum());
  54. GELOGD("after translate tf_desc, datatype: %s, format: %s, real size: %ld, size_type: %u",
  55. ge::TypeUtils::DataTypeToSerialString(ge_desc.GetDataType()).c_str(),
  56. ge::TypeUtils::FormatToSerialString(ge_desc.GetFormat()).c_str(), real_size * size_type, size_type);
  57. return SUCCESS;
  58. }
  59. Status TensorFlowSqueezeParser::ParseParams(const Message *op_src, ge::OpDescPtr &op) {
  60. GE_CHECK_NOTNULL(op_src);
  61. GE_CHECK_NOTNULL(op);
  62. const NodeDef *node = DOMI_DYNAMIC_CAST<const NodeDef *>(op_src);
  63. GE_CHECK_NOTNULL(node);
  64. GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str());
  65. bool has_axis = true;
  66. bool has_dims = true;
  67. domi::tensorflow::AttrValue axis;
  68. domi::tensorflow::AttrValue dims;
  69. if (!TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_AXIS, axis)) {
  70. has_axis = false;
  71. }
  72. if (!TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_DIMS, dims)) {
  73. has_dims = false;
  74. }
  75. if (!has_axis && !has_dims) {
  76. return SUCCESS;
  77. }
  78. if (has_axis && has_dims) {
  79. GELOGE(FAILED, "In NodeDef %s dim and axis is error.", node->name().c_str());
  80. return domi::PARAM_INVALID;
  81. }
  82. domi::tensorflow::AttrValue_ListValue values;
  83. if (has_axis) {
  84. values = axis.list();
  85. } else {
  86. values = dims.list();
  87. }
  88. int i = 0;
  89. int size = values.i_size();
  90. vector<int32_t> v_result;
  91. for (i = 0; i < size; i++) {
  92. int32_t result = values.i(i);
  93. v_result.push_back(result);
  94. }
  95. if (!ge::AttrUtils::SetListInt(op, SQUEEZE_ATTR_AXIS, v_result)) {
  96. GELOGE(FAILED, "Set squeeze axis attr failed");
  97. return FAILED;
  98. }
  99. domi::tensorflow::AttrValue input_attr_value;
  100. domi::tensorflow::AttrValue output_attr_value;
  101. GE_IF_BOOL_EXEC(
  102. GetParserContext().train_flag == true, ge::GeTensorDesc input_desc; ge::GeTensorDesc output_desc;
  103. if (TensorFlowUtil::FindAttrValue(node, ge::ATTR_NAME_INPUT_TENSOR_DESC, input_attr_value)) {
  104. GE_CHK_BOOL_RET_STATUS(ParseDesc(input_attr_value, input_desc) == SUCCESS, FAILED, "parse input desc failed");
  105. }
  106. if (TensorFlowUtil::FindAttrValue(node, ge::ATTR_NAME_OUTPUT_TENSOR_DESC, output_attr_value)) {
  107. GE_CHK_BOOL_RET_STATUS(ParseDesc(output_attr_value, output_desc) == SUCCESS, FAILED,
  108. "parse output desc failed");
  109. }
  110. if (!ge::AttrUtils::SetTensorDesc(op, RESHAPE_ATTR_NAME_INPUT_DESC, input_desc)) {
  111. GELOGE(FAILED, "Set input desc failed");
  112. return FAILED;
  113. } if (!ge::AttrUtils::SetTensorDesc(op, RESHAPE_ATTR_NAME_OUTPUT_DESC, output_desc)) {
  114. GELOGE(FAILED, "Set output desc failed");
  115. return FAILED;
  116. })
  117. return SUCCESS;
  118. }
  119. REGISTER_OP_PARSER_CREATOR(TENSORFLOW, SQUEEZE, TensorFlowSqueezeParser);
  120. } // namespace ge