You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensorflow_reshape_parser.cc 4.3 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "parser/tensorflow/tensorflow_reshape_parser.h"
  17. #include "framework/common/debug/ge_log.h"
  18. #include "graph/utils/type_utils.h"
  19. #include "parser/common/op_parser_factory.h"
  20. #include "parser/common/util.h"
  21. #include "parser/tensorflow/tensorflow_util.h"
  22. #include "parser/common/acl_graph_parser_util.h"
  23. #include "omg/parser/parser_inner_ctx.h"
  24. using domi::TENSORFLOW;
  25. using namespace ge::parser;
  26. namespace ge {
  27. Status TensorFlowReshapeParser::ParseDesc(const domi::tensorflow::AttrValue &attr_value, ge::GeTensorDesc &ge_desc) {
  28. int32_t tf_datatype = 0;
  29. auto a_list = attr_value.list();
  30. GE_CHK_BOOL_RET_STATUS(TensorFlowUtil::ParseFromAttrValueList(ge_desc, a_list, 0, tf_datatype), PARAM_INVALID,
  31. "parse ge_desc failed.");
  32. uint32_t size_type = 1;
  33. auto data_type = ge_desc.GetDataType();
  34. bool type_ret = ge::TypeUtils::GetDataTypeLength(data_type, size_type);
  35. GE_IF_BOOL_EXEC(!type_ret,
  36. REPORT_CALL_ERROR("E19999", "Data type %s is not supported",
  37. ge::TypeUtils::DataTypeToSerialString(data_type).c_str());
  38. GELOGE(FAILED, "Can't GetDataTypeLength of data_type: %s",
  39. ge::TypeUtils::DataTypeToSerialString(data_type).c_str());
  40. return PARAM_INVALID);
  41. // calculate size
  42. int64_t real_size = 1;
  43. for (uint32_t j = 0; j < ge_desc.GetShape().GetDimNum(); ++j) {
  44. int64_t tmp_dim = ge_desc.GetShape().GetDim(j);
  45. GE_IF_BOOL_EXEC(tmp_dim < 0, real_size = tmp_dim * (-1) * real_size; continue;);
  46. real_size *= tmp_dim;
  47. }
  48. PARSER_INT64_MULCHECK(real_size, size_type);
  49. ge::TensorUtils::SetSize(ge_desc, real_size * size_type);
  50. ge::TensorUtils::SetRealDimCnt(ge_desc, ge_desc.GetShape().GetDimNum());
  51. GELOGI("after translate tf_desc, datatype: %s, format: %s, real size: %ld, size_type: %u",
  52. ge::TypeUtils::DataTypeToSerialString(ge_desc.GetDataType()).c_str(),
  53. ge::TypeUtils::FormatToSerialString(ge_desc.GetFormat()).c_str(), real_size * size_type, size_type);
  54. return SUCCESS;
  55. }
  56. Status TensorFlowReshapeParser::ParseParams(const Message *op_src, ge::OpDescPtr &op) {
  57. GE_CHECK_NOTNULL(op_src);
  58. GE_CHECK_NOTNULL(op);
  59. const domi::tensorflow::NodeDef *node_src = DOMI_DYNAMIC_CAST<const domi::tensorflow::NodeDef *>(op_src);
  60. GE_CHECK_NOTNULL(node_src);
  61. GELOGD("TF op node name = %s, op type= %s, parse params", node_src->name().c_str(), node_src->op().c_str());
  62. domi::tensorflow::AttrValue input_attr_value;
  63. domi::tensorflow::AttrValue output_attr_value;
  64. GE_IF_BOOL_EXEC(
  65. GetParserContext().train_flag,
  66. ge::GeTensorDesc input_desc;
  67. ge::GeTensorDesc output_desc;
  68. if (TensorFlowUtil::FindAttrValue(node_src, ge::ATTR_NAME_INPUT_TENSOR_DESC, input_attr_value)) {
  69. GE_CHK_BOOL_RET_STATUS(ParseDesc(input_attr_value, input_desc) == SUCCESS, FAILED, "parse input desc failed");
  70. }
  71. if (TensorFlowUtil::FindAttrValue(node_src, ge::ATTR_NAME_OUTPUT_TENSOR_DESC, output_attr_value)) {
  72. GE_CHK_BOOL_RET_STATUS(ParseDesc(output_attr_value, output_desc) == SUCCESS, FAILED,
  73. "parse output desc failed");
  74. }
  75. GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::SetTensorDesc(op, RESHAPE_ATTR_NAME_INPUT_DESC, input_desc), FAILED,
  76. "set input desc failed");
  77. GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::SetTensorDesc(op, RESHAPE_ATTR_NAME_OUTPUT_DESC, output_desc), FAILED,
  78. "set output desc failed"););
  79. return SUCCESS;
  80. }
  // Registers TensorFlowReshapeParser as the parser creator for RESHAPE ops of the TENSORFLOW
  // framework. The macro presumably comes from parser/common/op_parser_factory.h (included above)
  // and token-pastes its arguments, so they must stay unqualified identifiers.
  81. REGISTER_OP_PARSER_CREATOR(TENSORFLOW, RESHAPE, TensorFlowReshapeParser);
  82. } // namespace ge