You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensorflow_util.h 8.1 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
3 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef OMG_PARSER_TENSORFLOW_TENSORFLOW_UTIL_H_
  17. #define OMG_PARSER_TENSORFLOW_TENSORFLOW_UTIL_H_
  18. #include <map>
  19. #include <string>
  20. #include <unordered_map>
  21. #include "parser/common/op_def/operator.h"
  22. #include "external/graph/attr_value.h"
  23. #include "external/graph/graph.h"
  24. #include "framework/omg/parser/parser_types.h"
  25. #include "framework/omg/omg_inner_types.h"
  26. #include "graph/compute_graph.h"
  27. #include "graph/ge_tensor.h"
  28. #include "graph/op_desc.h"
  29. #include "graph/utils/attr_utils.h"
  30. #include "graph/utils/graph_utils.h"
  31. #include "graph/utils/op_desc_utils.h"
  32. #include "graph/utils/tensor_utils.h"
  33. #include "proto/tensorflow/graph.pb.h"
  34. namespace ge {
  // ---------------------------------------------------------------------------------------------
  // Every constant in this header is only declared here (extern, no initializer); the values
  // live in a separate translation unit.
  // ---------------------------------------------------------------------------------------------

  // TensorFlow attribute type names: each scalar form is followed by its list form.
  extern const std::string TENSORFLOW_ATTR_TYPE_STRING;
  extern const std::string TENSORFLOW_ATTR_LIST_TYPE_STRING;
  extern const std::string TENSORFLOW_ATTR_TYPE_INT;
  extern const std::string TENSORFLOW_ATTR_LIST_TYPE_INT;
  extern const std::string TENSORFLOW_ATTR_TYPE_FLOAT;
  extern const std::string TENSORFLOW_ATTR_LIST_TYPE_FLOAT;
  extern const std::string TENSORFLOW_ATTR_TYPE_BOOL;
  extern const std::string TENSORFLOW_ATTR_LIST_TYPE_BOOL;
  extern const std::string TENSORFLOW_ATTR_TYPE_TYPE;
  extern const std::string TENSORFLOW_ATTR_LIST_TYPE_TYPE;
  extern const std::string TENSORFLOW_ATTR_TYPE_SHAPE;
  extern const std::string TENSORFLOW_ATTR_LIST_TYPE_SHAPE;
  extern const std::string TENSORFLOW_ATTR_TYPE_TENSOR;
  extern const std::string TENSORFLOW_ATTR_LIST_TYPE_TENSOR;
  extern const std::string TENSORFLOW_ATTR_TYPE_FUNC;
  extern const std::string TENSORFLOW_ATTR_LIST_TYPE_FUNC;

  // TensorFlow attribute names (alphabetical).
  extern const std::string TENSORFLOW_ATTR_ALPHA;
  extern const std::string TENSORFLOW_ATTR_AXIS;
  extern const std::string TENSORFLOW_ATTR_BETA;
  extern const std::string TENSORFLOW_ATTR_BIAS;
  extern const std::string TENSORFLOW_ATTR_DATA_FORMAT;
  extern const std::string TENSORFLOW_ATTR_DEPTH_RADIUS;
  extern const std::string TENSORFLOW_ATTR_DILATIONS;
  extern const std::string TENSORFLOW_ATTR_DSTT;
  extern const std::string TENSORFLOW_ATTR_DTYPE;
  extern const std::string TENSORFLOW_ATTR_INDEX;
  extern const std::string TENSORFLOW_ATTR_KSIZE;
  extern const std::string TENSORFLOW_ATTR_MODE;
  extern const std::string TENSORFLOW_ATTR_N;
  extern const std::string TENSORFLOW_ATTR_OUTPUT_OP;
  extern const std::string TENSORFLOW_ATTR_PADDING;
  extern const std::string TENSORFLOW_ATTR_PERM;
  extern const std::string TENSORFLOW_ATTR_SHAPE;
  extern const std::string TENSORFLOW_ATTR_SRCT;
  extern const std::string TENSORFLOW_ATTR_STRIDES;
  extern const std::string TENSORFLOW_ATTR_T;
  extern const std::string TENSORFLOW_ATTR_TAXIS;
  extern const std::string TENSORFLOW_ATTR_TIDX;
  extern const std::string TENSORFLOW_ATTR_TINDICES;
  extern const std::string TENSORFLOW_ATTR_TMULTIPLES;
  extern const std::string TENSORFLOW_ATTR_TPADDINGS;
  extern const std::string TENSORFLOW_ATTR_TPARAMS;
  extern const std::string TENSORFLOW_ATTR_TRANSINPUT;
  extern const std::string TENSORFLOW_ATTR_TRANSWEIGHT;
  extern const std::string TENSORFLOW_ATTR_TSHAPE;
  extern const std::string TENSORFLOW_ATTR_VALUE;

  // Node op names (alphabetical). The "TENSORFLOWF_" prefix is kept as-is because callers use it.
  extern const std::string TENSORFLOWF_NODE_OP_ADDN;
  extern const std::string TENSORFLOWF_NODE_OP_CONST;
  extern const std::string TENSORFLOWF_NODE_OP_IDENTITY;
  extern const std::string TENSORFLOWF_NODE_OP_MATMUL;
  extern const std::string TENSORFLOWF_NODE_OP_MERGE;
  extern const std::string TENSORFLOWF_NODE_OP_PLACEHOLDER;
  extern const std::string TENSORFLOWF_NODE_OP_RELU;
  extern const std::string TENSORFLOWF_NODE_OP_SHAPE;
  extern const std::string TENSORFLOWF_NODE_OP_SWITCH;
  extern const std::string TENSORFLOWF_NODE_OP_TRANSPOSE;

  // data_format values.
  extern const std::string TENSORFLOWF_TENSOR_NCHW;
  extern const std::string TENSORFLOWF_TENSOR_NHWC;

  // Convolution stride / dilation counts (int, unlike the unsigned sizes below).
  extern const int TENSORFLOW_CONV_STRIDE_NUM;
  extern const int TENSORFLOW_CONV_DILATION_NUM;

  // padding values.
  extern const std::string TENSORFLOWF_OP_PADDING_VALID;
  extern const std::string TENSORFLOWF_OP_PADDING_SAME;

  // Normal input sizes.
  extern const uint32_t TENSORFLOW_NORMAL_INPUT_SIZE_MATMUL;
  extern const uint32_t TENSORFLOW_NORMAL_INPUT_SIZE_RESHAPE;
  extern const uint32_t TENSORFLOW_NORMAL_INPUT_SIZE_POOL;

  // Normal weight sizes.
  extern const uint32_t TENSORFLOW_NORMAL_WEIGHT_SIZE_MATMUL;
  extern const uint32_t TENSORFLOW_NORMAL_WEIGHT_SIZE_RESHAPE;

  // Flags distinguishing input tensors from output tensors.
  extern const uint32_t TENSORFLOW_NORMAL_INPUT_TENSOR_FLAG;
  extern const uint32_t TENSORFLOW_NORMAL_OUTPUT_TENSOR_FLAG;
  112. class TensorFlowUtil {
  113. public:
  114. /**
  115. * @ingroup domi_omg
  116. * @brief find the corresponding AttrValue in NodeDef
  117. * @param [in] nodeDef nodedef object to find
  118. * @param [in] attr_name attribute name
  119. * @param [out] attr_value attribute value
  120. * @return true attribute exists
  121. * @return false attribute does not exist
  122. *
  123. */
  124. static bool FindAttrValue(const domi::tensorflow::NodeDef *const node_def, const std::string &attr_name,
  125. domi::tensorflow::AttrValue &attr_value);
  126. /**
  127. * @ingroup domi_omg
  128. * @brief Check the actual type and expected type of the AttrValue, int, float, list (int), list (bool), etc.
  129. * @param [in] attr_value attrValue to check
  130. * @param [in] type expected attribute type
  131. * @return SUCCESS success
  132. * @return FAILED failed
  133. *
  134. */
  135. static domi::Status CheckAttrHasType(const domi::tensorflow::AttrValue &attr_value, const std::string &type);
  136. /**
  137. * @ingroup domi_omg
  138. * @brief parsing data types
  139. * @param [in] node_src node to be parsed
  140. * @param [in] attr_src attribute to be parsed
  141. * @param [out] data_type parsed data type
  142. * @return SUCCESS Parsing success
  143. * @return FAILED parsing failed
  144. *
  145. */
  146. static domi::Status ParseDataType(const domi::tensorflow::NodeDef *node_src,
  147. const std::string &attr_src,
  148. domi::tensorflow::DataType &data_type);
  149. /**
  150. * @ingroup domi_omg
  151. * @brief parsing data types
  152. * @param [in] attr_value attr in NodeDef to be converted
  153. * @param [out] op the parsed information is stored in the properties of the parent class
  154. * @return SUCCESS conversion success
  155. * @return FAILED conversion failed
  156. *
  157. */
  158. static domi::Status TransTensorDescriptor(const domi::tensorflow::AttrValue &attr_value,
  159. ParserOperator *const op,
  160. const uint32_t io,
  161. const std::string &type = "");
  162. /*
  163. * @brief 添加NodeDef属性
  164. * @param [in] attr_name attribute name
  165. * @param [in] attr_value attribute Value Object
  166. * @param [out] node_def
  167. * @return void
  168. *
  169. */
  170. static void AddNodeAttr(const std::string &attr_name,
  171. const domi::tensorflow::AttrValue &value,
  172. domi::tensorflow::NodeDef *const node_def);
  173. static domi::Status ClearUnusedParam(ge::ComputeGraphPtr &graph);
  174. static bool ParseFromAttrValueList(ge::GeTensorDesc &ge_desc,
  175. const domi::tensorflow::AttrValue_ListValue &a_list,
  176. int32_t i,
  177. int32_t &tf_datatype);
  178. };
  179. } // namespace ge
  180. #endif // OMG_PARSER_TENSORFLOW_TENSORFLOW_UTIL_H_