You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

acl_graph_parser_util.h 9.5 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. /**
  2. * Copyright (c) Huawei Technologies Co., Ltd. 2020. All rights reserved.
  3. * Licensed under the Apache License, Version 2.0 (the "License");
  4. * you may not use this file except in compliance with the License.
  5. * You may obtain a copy of the License at
  6. * http://www.apache.org/licenses/LICENSE-2.0
  7. * Unless required by applicable law or agreed to in writing, software
  8. * distributed under the License is distributed on an "AS IS" BASIS,
  9. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. * See the License for the specific language governing permissions and
  11. * limitations under the License.
  12. */
  13. #ifndef ACL_GRAPH_PARSE_UTIL_
  14. #define ACL_GRAPH_PARSE_UTIL_
  #include <cstdint>
  #include <map>
  #include <memory>
  #include <new>
  #include <sstream>
  #include <string>
  #include <type_traits>
  #include <unordered_map>
  #include <utility>
  #include <vector>
  #include <google/protobuf/text_format.h>
  #include "framework/omg/parser/parser_types.h"
  #include "graph/ascend_string.h"
  #include "graph/utils/graph_utils.h"
  #include "register/register_error_codes.h"
  25. namespace ge {
  26. using google::protobuf::Message;
  27. class AclGraphParseUtil {
  28. public:
  29. AclGraphParseUtil() {}
  30. virtual ~AclGraphParseUtil() {}
  31. static domi::Status LoadOpsProtoLib();
  32. static void SaveCustomCaffeProtoPath();
  33. domi::Status AclParserInitialize(const std::map<std::string, std::string> &options);
  34. domi::Status SetOutputNodeInfo(ge::Graph &graph, const std::map<AscendString, AscendString> &parser_params) const;
  35. domi::Status ParseParamsBeforeGraph(const std::map<AscendString, AscendString> &parser_params,
  36. std::string &graph_name) const;
  37. domi::Status ParseParamsAfterGraph(ge::Graph &graph, const std::map<AscendString,
  38. AscendString> &parser_params) const;
  39. private:
  40. bool parser_initialized = false;
  41. domi::Status CheckOptions(const std::map<AscendString, AscendString> &parser_params) const;
  42. domi::Status GetOutputLeaf(NodePtr node, std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) const;
  43. void CreateOutputNodesInfo(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info,
  44. std::vector<std::string> &output_nodes_name) const;
  45. static void SetDefaultFormat();
  46. domi::Status ParseAclOutputNodes(const std::string &out_nodes) const;
  47. domi::Status ParseAclOutputFp16NodesFormat(const std::string &is_output_fp16) const;
  48. domi::Status ParseAclEnableScope(const std::string &enable_scope_fusion_passes) const;
  49. static void AddAttrsForInputNodes(const vector<string> &adjust_fp16_format_vec, const string &fp16_nodes_name,
  50. size_t index, OpDescPtr &op_desc);
  51. domi::Status ParseAclInputFp16Nodes(const ComputeGraphPtr &graph, const string &input_fp16_nodes,
  52. const string &is_input_adjust_hw_layout) const;
  53. domi::Status SetSpecifyIndexAttrByInputNames(const ComputeGraphPtr &graph, const std::string &input_data_names) const;
  54. domi::Status GetDefaultOutInfo(ge::ComputeGraphPtr &compute_graph,
  55. std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) const;
  56. };
  57. namespace parser {
  58. /// @ingroup: domi_common
  59. /// @brief: get length of file
  60. /// @param [in] input_file: path of file
  61. /// @return long: File length. If the file length fails to be obtained, the value -1 is returned.
  62. extern long GetFileLength(const std::string &input_file);
  63. /// @ingroup domi_common
  64. /// @brief Absolute path for obtaining files.
  65. /// @param [in] path of input file
  66. /// @param [out] Absolute path of a file. If the absolute path cannot be obtained, an empty string is returned
  67. std::string RealPath(const char *path);
  68. /// @ingroup domi_common
  69. /// @brief Obtains the absolute time (timestamp) of the current system.
  70. /// @return Timestamp, in microseconds (US)
  71. uint64_t GetCurrentTimestamp();
  72. /// @ingroup domi_common
  73. /// @brief Reads all data from a binary file.
  74. /// @param [in] file_name path of file
  75. /// @param [out] buffer Output memory address, which needs to be released by the caller.
  76. /// @param [out] length Output memory size
  77. /// @return false fail
  78. /// @return true success
  79. bool ReadBytesFromBinaryFile(const char *file_name, char **buffer, int &length);
  80. /// @ingroup domi_common
  81. /// @brief proto file in bianary format
  82. /// @param [in] file path of proto file
  83. /// @param [out] proto memory for storing the proto file
  84. /// @return true success
  85. /// @return false fail
  86. bool ReadProtoFromBinaryFile(const char *file, Message *proto);
  87. /// @ingroup domi_common
  88. /// @brief Reads the proto structure from an array.
  89. /// @param [in] data proto data to be read
  90. /// @param [in] size proto data size
  91. /// @param [out] proto Memory for storing the proto file
  92. /// @return true success
  93. /// @return false fail
  94. bool ReadProtoFromArray(const void *data, int size, Message *proto);
  95. /// @ingroup domi_proto
  96. /// @brief Reads the proto file in the text format.
  97. /// @param [in] file path of proto file
  98. /// @param [out] message Memory for storing the proto file
  99. /// @return true success
  100. /// @return false fail
  101. bool ReadProtoFromText(const char *file, google::protobuf::Message *message);
  102. bool ReadProtoFromMem(const char *data, int size, google::protobuf::Message *message);
  103. /// @brief get the Original Type of FrameworkOp
  104. /// @param [in] node
  105. /// @param [out] type
  106. /// @return Status
  107. domi::Status GetOriginalType(const ge::NodePtr &node, string &type);
  108. /// @ingroup domi_common
  109. /// @brief Check whether the file path meets the whitelist verification requirements.
  110. /// @param [in] filePath file path
  111. /// @param [out] result
  112. bool ValidateStr(const std::string &filePath, const std::string &mode);
  113. /// @ingroup domi_common
  114. /// @brief Obtains the current time string.
  115. /// @return Time character string in the format: %Y%m%d%H%M%S, eg: 20171011083555
  116. std::string CurrentTimeInStr();
  /// @brief Allocate a T with non-throwing new and hand it to a std::shared_ptr.
  /// @param [in] args constructor arguments, perfectly forwarded to T's constructor
  /// @return shared_ptr owning the new object, or an empty shared_ptr (== nullptr) if allocation failed.
  template <typename T, typename... Args>
  inline std::shared_ptr<T> MakeShared(Args &&... args) {
    // Allocate the non-const type; the pointer converts implicitly when wrapped as shared_ptr<T>.
    using T_nc = typename std::remove_const<T>::type;
    T_nc *raw = new (std::nothrow) T_nc(std::forward<Args>(args)...);
    if (raw == nullptr) {
      // Allocation failed: return an empty pointer without building a control block. Wrapping a null raw
      // pointer in shared_ptr would still allocate one, which can throw std::bad_alloc precisely when
      // memory is exhausted — defeating the point of using nothrow new.
      return nullptr;
    }
    return std::shared_ptr<T>(raw);
  }
  123. /// @ingroup math_util
  124. /// @brief check whether int64 multiplication can result in overflow
  125. /// @param [in] a multiplicator
  126. /// @param [in] b multiplicator
  127. /// @return Status
  128. inline domi::Status Int64MulCheckOverflow(int64_t a, int64_t b) {
  129. if (a > 0) {
  130. if (b > 0) {
  131. if (a > (INT64_MAX / b)) {
  132. return domi::FAILED;
  133. }
  134. } else {
  135. if (b < (INT64_MIN / a)) {
  136. return domi::FAILED;
  137. }
  138. }
  139. } else {
  140. if (b > 0) {
  141. if (a < (INT64_MIN / b)) {
  142. return domi::FAILED;
  143. }
  144. } else {
  145. if ((a != 0) && (b < (INT64_MAX / a))) {
  146. return domi::FAILED;
  147. }
  148. }
  149. }
  150. return domi::SUCCESS;
  151. }
  152. /// @ingroup math_util
  153. /// @brief check whether int64 multiplication can result in overflow
  154. /// @param [in] a multiplicator
  155. /// @param [in] b multiplicator
  156. /// @return Status
  157. inline domi::Status CheckInt64Uint32MulOverflow(int64_t a, uint32_t b) {
  158. if (a == 0 || b == 0) {
  159. return domi::SUCCESS;
  160. }
  161. if (a > 0) {
  162. if (a > (INT64_MAX / b)) {
  163. return domi::FAILED;
  164. }
  165. } else {
  166. if (a < (INT64_MIN / b)) {
  167. return domi::FAILED;
  168. }
  169. }
  170. return domi::SUCCESS;
  171. }
  /// @brief Log a warning and return INTERNAL_ERROR from the enclosing function if int64 (a) * (b) would overflow.
  /// The result of Int64MulCheckOverflow is compared against domi::SUCCESS explicitly, so the macro does not
  /// depend on which SUCCESS happens to be visible at the call site.
  /// NOTE: (a) and (b) are evaluated twice (check and log); do not pass expressions with side effects.
  #define PARSER_INT64_MULCHECK(a, b) \
    if (ge::parser::Int64MulCheckOverflow((a), (b)) != domi::SUCCESS) { \
      GELOGW("Int64 %ld and %ld multiplication can result in overflow!", static_cast<int64_t>(a), \
             static_cast<int64_t>(b)); \
      return INTERNAL_ERROR; \
    }
  /// @brief Log a warning and return INTERNAL_ERROR from the enclosing function if int64 (a) * uint32 (b)
  /// would overflow.
  /// The int64 operand is logged as int64_t to match "%ld"; the original uint64_t cast mismatched the format's
  /// signedness and misreported negative operands.
  /// NOTE: (a) and (b) are evaluated twice (check and log); do not pass expressions with side effects.
  #define PARSER_INT64_UINT32_MULCHECK(a, b) \
    if (ge::parser::CheckInt64Uint32MulOverflow((a), (b)) != domi::SUCCESS) { \
      GELOGW("Int64 %ld and Uint32 %u multiplication can result in overflow!", static_cast<int64_t>(a), \
             static_cast<uint32_t>(b)); \
      return INTERNAL_ERROR; \
    }
  184. } // namespace parser
  185. } // namespace ge
  /* PARSER_TIMESTAMP_START expands to a declaration, not an expression, so lint 773 is suppressed for it.
     The suppressions previously named GE_TIMESTAMP_START, which is not defined here, so they never applied. */
  /*lint --emacro((773),PARSER_TIMESTAMP_START)*/
  /*lint -esym(773,PARSER_TIMESTAMP_START)*/
  /// @brief Declare local `startUsec_<stage>` holding ge::parser::GetCurrentTimestamp() (microseconds).
  #define PARSER_TIMESTAMP_START(stage) uint64_t startUsec_##stage = ge::parser::GetCurrentTimestamp()
  /// @brief Log (GELOGI) the microseconds elapsed since PARSER_TIMESTAMP_START(stage) in the same scope.
  /// NOTE: the expansion keeps its historical trailing ';' after while (0) so existing call sites still compile.
  #define PARSER_TIMESTAMP_END(stage, stage_name) \
    do { \
      uint64_t endUsec_##stage = ge::parser::GetCurrentTimestamp(); \
      GELOGI("[GEPERFTRACE] The time cost of %s is [%lu] micro second.", (stage_name), \
             (endUsec_##stage - startUsec_##stage)); \
    } while (0);
  /// @brief Same as PARSER_TIMESTAMP_END, but reports through GEEVENT instead of GELOGI.
  #define PARSER_TIMESTAMP_EVENT_END(stage, stage_name) \
    do { \
      uint64_t endUsec_##stage = ge::parser::GetCurrentTimestamp(); \
      GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second.", (stage_name), \
              (endUsec_##stage - startUsec_##stage)); \
    } while (0);
  201. #endif // ACL_GRAPH_PARSE_UTIL_