You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

caffe_parser.h 15 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef PARSER_CAFFE_CAFFE_PARSER_H_
  17. #define PARSER_CAFFE_CAFFE_PARSER_H_
  // PARSER_FUNC_VISIBILITY annotates the exported parser classes in this header.
  // When FUNC_VISIBILITY is defined it marks symbols as exported (dllexport on MSVC,
  // default ELF visibility elsewhere); otherwise it expands to nothing.
  // Guarded so an earlier definition (e.g. from another header that defines the same
  // macro) is reused instead of being redefined with a different token sequence.
  #ifndef PARSER_FUNC_VISIBILITY
  #if defined(_MSC_VER)
  #ifdef FUNC_VISIBILITY
  // __declspec is the documented spelling; _declspec is only a legacy synonym that MSVC
  // rejects when language extensions are disabled (/Za).
  #define PARSER_FUNC_VISIBILITY __declspec(dllexport)
  #else
  #define PARSER_FUNC_VISIBILITY
  #endif
  #else
  #ifdef FUNC_VISIBILITY
  #define PARSER_FUNC_VISIBILITY __attribute__((visibility("default")))
  #else
  #define PARSER_FUNC_VISIBILITY
  #endif
  #endif
  #endif  // PARSER_FUNC_VISIBILITY
  31. #include <map>
  32. #include <set>
  33. #include <string>
  34. #include <unordered_map>
  35. #include <utility>
  36. #include <vector>
  37. #include "omg/parser/op_parser.h"
  38. #include "omg/parser/model_parser.h"
  39. #include "omg/parser/weights_parser.h"
  40. #include "proto/caffe/caffe.pb.h"
  41. #include "proto/om.pb.h"
  42. namespace ge {
  43. using domi::caffe::NetParameter;
  44. using std::map;
  45. using std::set;
  46. using std::string;
  47. using std::unordered_map;
  48. using std::vector;
  49. static std::map<std::vector<std::string>, std::vector<std::string>> params_share_map;
  50. class PARSER_FUNC_VISIBILITY CaffeModelParser : public domi::ModelParser {
  51. public:
  52. CaffeModelParser() {}
  53. virtual ~CaffeModelParser() {}
  54. /**
  55. * @ingroup domi_omg
  56. * @brief Parse the relevant data from the model file and save it to graph
  57. * @param [in] file Path of model file
  58. * @param [in|out] graph graph for saving model information
  59. * @return SUCCESS parse successfully
  60. * @return FAILED parse failed
  61. */
  62. Status Parse(const char *file, ge::Graph &graph) override;
  63. /**
  64. * @ingroup domi_omg
  65. * @brief Parse the relevant data from memory and save it to graph
  66. * @param [in] memory buffer of model file
  67. * @param [in] buffer size
  68. * @param [in|out] graph graph for saving model information
  69. * @return SUCCESS parse successfully
  70. * @return FAILED parse failed
  71. */
  72. Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override;
  73. Status ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) override {
  74. (void)data;
  75. (void)size;
  76. (void)graph;
  77. return domi::SUCCESS;
  78. }
  79. /**
  80. * @ingroup domi_omg
  81. * @brief Convert model files to JSON format
  82. * @param [in] model_file Path of model file
  83. * @param [out] json_file Converted JSON file path
  84. * @return SUCCESS parse successfully
  85. * @return others parse failed
  86. */
  87. Status ToJson(const char *model_file, const char *json_file) override;
  88. /**
  89. * @ingroup domi_omg
  90. * @brief Parse the relevant data from the model file and save it to graph
  91. * @param [in] graph_def input tensorflow model
  92. * @param [in|out] graph graph for saving model information
  93. * @return SUCCESS parse successfully
  94. * @return FAILED parse failed
  95. */
  96. Status ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) override;
  97. Status ParseProtoWithSubgraph(const google::protobuf::Message *root_proto, domi::GetGraphCallback callback,
  98. ge::ComputeGraphPtr &graph) override;
  99. /*
  100. * @ingroup domi_omg
  101. * @brief Mapping CAFFE's datatype to GE's datatype
  102. * @param [in] type, datatype types of operators in CAFFE networks
  103. * @return ge::DataType
  104. */
  105. ge::DataType ConvertToGeDataType(const uint32_t type) override {
  106. (void)type;
  107. return ge::DT_FLOAT;
  108. }
  109. Status ParseAllGraph(const google::protobuf::Message *root_proto, ge::ComputeGraphPtr &root_graph) override {
  110. (void)root_proto;
  111. (void)root_graph;
  112. return domi::SUCCESS;
  113. }
  114. private:
  115. Status Parse(const char *file, ge::ComputeGraphPtr &graph);
  116. /**
  117. * @ingroup domi_omg
  118. * @brief Add the Layer in the model to the PreChecker
  119. * @param [in] net caffe net information
  120. * @return SUCCESS build successfully
  121. * @return FAILED build failed
  122. */
  123. Status PreCheck(const domi::caffe::NetParameter &net);
  124. /**
  125. * @ingroup domi_omg
  126. * @brief Parsing input related information from model files
  127. * @param [in] proto_message caffe net information
  128. * @param [in|out] net_input_name Used to store the acquired input name information
  129. * @param [in|out] net_input_data Used to store the acquired input data information
  130. * @return SUCCESS build successfully
  131. * @return FAILED build failed
  132. */
  133. Status ParseInput(domi::caffe::NetParameter &proto_message, bool &input_data_flag);
  134. /*
  135. * @ingroup domi_omg
  136. * @brief Parse model by custom proto and save info to operators
  137. * @param [in] model_path, file path of model(prototxt file)
  138. * @param [in] custom_proto, file path of custom proto
  139. * @param [in] caffe_proto, file path of caffe proto
  140. * @param [out] operators, operators saving custom info
  141. * @return SUCCESS parse successfully
  142. * @return FAILED parse failed
  143. */
  144. Status CustomProtoParse(const char *model_path, const string &custom_proto, const string &caffe_proto,
  145. std::vector<ge::Operator> &operators);
  146. /*
  147. * @ingroup domi_omg
  148. * @brief Parse model by custom proto and save info to operators
  149. * @param [in] model_path, file path of model(prototxt file)
  150. * @param [in] custom_proto_path, file path of custom proto
  151. * @param [in] custom_proto_name, custom proto name
  152. * @param [out] operators, operators saving custom info
  153. * @return SUCCESS parse successfully
  154. * @return FAILED parse failed
  155. */
  156. Status ParseNetModelByCustomProto(const char *model_path, const string &custom_proto_path,
  157. const string &custom_proto_name, std::vector<ge::Operator> &operators);
  158. /*
  159. * @ingroup domi_omg
  160. * @brief Read caffe model and shield google warning
  161. * @param [in] model_path, file path of model(prototxt file)
  162. * @param [out] message, message saving custom info
  163. * @return SUCCESS read file successfully
  164. * @return FAILED read file failed
  165. */
  166. Status ReadModelWithoutWarning(const char *model_path, google::protobuf::Message *message);
  167. /*
  168. * @ingroup domi_omg
  169. * @brief Read caffe model and save it to message
  170. * @param [in] model_path, file path of model(prototxt file)
  171. * @param [out] message, message saving custom info
  172. * @return SUCCESS read file successfully
  173. * @return FAILED read file failed
  174. */
  175. Status ReadCaffeModelFromText(const char *model_path, google::protobuf::Message *message);
  176. /*
  177. * @ingroup domi_omg
  178. * @brief Parse layer message and save custom info to operators
  179. * @param [in] layer_descriptor, layer description of message
  180. * @param [in] message, message of model
  181. * @param [out] operators, operators saving custom info
  182. * @return SUCCESS parse layer successfully
  183. * @return FAILED parse layer failed
  184. */
  185. Status ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor,
  186. const google::protobuf::Message *message, std::vector<ge::Operator> &operators);
  187. /*
  188. * @ingroup domi_omg
  189. * @brief Create custom operator by op_name and op_type
  190. * @param [in] op_name, name of operator
  191. * @param [in] op_type, type of operator
  192. * @param [in] message, message of model
  193. * @param [in] index, index of field
  194. * @param [out] operators, operators saving custom info
  195. * @return SUCCESS create operator successfully
  196. * @return FAILED create operator failed
  197. */
  198. Status CreateCustomOperator(std::string op_name, std::string op_type, const google::protobuf::Message *message,
  199. int index, std::vector<ge::Operator> &operators);
  200. /**
  201. * @ingroup domi_omg
  202. * @brief Add blob information to the bottom_blobs_map and top_blobs_map_
  203. * @param [in] layer layer information
  204. * @param [in|out] inplace_blob_name_remapping save blob information
  205. * @return Status
  206. */
  207. Status AddBlobsToMap(const domi::caffe::LayerParameter &layer,
  208. std::map<std::string, std::string> &inplace_blob_name_remapping);
  209. /**
  210. * @ingroup domi_omg
  211. * @brief Add node information to graph
  212. * @param [in] layer layer infromation
  213. * @param [in|out] graph graph for saving model information
  214. * @return SUCCESS add successfully
  215. * @return FAILED add failed
  216. */
  217. Status AddNode(const domi::caffe::LayerParameter &layer, ge::ComputeGraphPtr &graph);
  218. /**
  219. * @ingroup domi_omg
  220. * @brief Add edge information to graph
  221. * @param [in|out] graph graph for saving model information
  222. * @return SUCCESS add successfully
  223. * @return FAILED add failed
  224. */
  225. Status AddEdges(ge::ComputeGraphPtr &graph);
  226. /**
  227. * @ingroup domi_omg
  228. * @brief Add top name information to graph
  229. * @param [in|out] proto_message
  230. * @return SUCCESS add successfully
  231. * @return FAILED add failed
  232. */
  233. Status AddOutputTop(const domi::caffe::NetParameter &proto_message);
  234. /**
  235. * @ingroup domi_omg
  236. * @brief Check if the current layer is valid
  237. * @return true valid
  238. * @return false invalid
  239. */
  240. bool CheckValidLayer(const domi::caffe::LayerParameter &layer);
  241. /**
  242. * @ingroup domi_omg
  243. * @brief Check whether the top of the current layer is 'Inplace'
  244. * @return true is 'Inplace'
  245. * @return false not is 'Inplace'
  246. */
  247. bool IsInplaceTopBlob(const domi::caffe::LayerParameter &layer, const std::string &top_name);
  248. /**
  249. * @ingroup domi_omg
  250. * @brief Check whether the top of the current layer is user's specified output top
  251. * @return true yes
  252. * @return false no
  253. */
  254. bool IsOutputTop(const string &op_name, int32_t index);
  255. /**
  256. * @ingroup domi_omg
  257. * @brief Find a layer set with the same param
  258. * @param [in] Param name set of each layer
  259. * @param [in|out] Layer set of the same param
  260. * @return Status
  261. */
  262. Status FindShareParamLayers(const std::map<std::string, std::vector<std::string>> &);
  263. Status AddTensorDescToOpDesc(ge::OpDescPtr &op_desc, const domi::caffe::LayerParameter &layer);
  264. Status AddTensorDescToOpDescByIr(ge::OpDescPtr &op_desc, const domi::caffe::LayerParameter &layer,
  265. const string &op_type);
  266. Status AddUserOutNodesTop();
  267. std::string RemapTopNameByLayer(const domi::caffe::LayerParameter &layer, const std::string &top_name, int index);
  268. Status GetCustomOp(const domi::caffe::LayerParameter &layer, vector<ge::Operator> &operators);
  269. bool IsOpAttrEmpty(const ge::Operator &op, const std::string &type);
  270. Status ParseOpParam(const domi::caffe::LayerParameter &layer, ge::OpDescPtr &op,
  271. std::shared_ptr<ge::OpParser> &op_parser);
  272. void SaveOrigionLayerTops(domi::caffe::LayerParameter &layer);
  273. Status ReorderInput(domi::caffe::NetParameter &net);
  274. void AddOutputInfoToContext(string layer_name, int32_t top_index);
  275. Status ParseOutputNodeTopInfo(const domi::caffe::NetParameter &proto_message);
  276. Status SaveDataLayerTops(const domi::caffe::LayerParameter &layer);
  277. std::map<std::string, ge::NodePtr> node_map;
  278. // key: blob name, value: layer name and index
  279. std::map<std::string, std::vector<std::pair<std::string, int32_t>>> bottom_blobs_map_;
  280. // key: blob name, value: layer name and index
  281. std::map<std::string, std::vector<std::pair<std::string, int32_t>>> top_blobs_map_;
  282. std::vector<ge::Operator> custom_operator_;
  283. std::map<std::string, std::vector<std::string>> layer_tops_map_;
  284. };
  285. /**
  286. * @ingroup domi_omg
  287. * @brief Caffe weight parser
  288. */
  289. class PARSER_FUNC_VISIBILITY CaffeWeightsParser : public domi::WeightsParser {
  290. public:
  291. /**
  292. * @ingroup domi_omg
  293. * @brief Parse weight data from file and save to graph
  294. * @param [in] file Path of weight file after training
  295. * @param [in|out] graph Save weight information after parsing
  296. * @return SUCCESS parse successfully
  297. * @return PARAM_INVALID param invalid
  298. * @return PARSE_WEIGHTS_FAILED parse failed
  299. */
  300. Status Parse(const char *file, ge::Graph &graph) override;
  301. Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override;
  302. private:
  303. Status CheckNodes(ge::ComputeGraphPtr &graph);
  304. /**
  305. * @ingroup domi_omg
  306. * @brief Convert netparameter to modedef and save in graph
  307. * @param [in] param Caffe network parameters to be converted
  308. * @param [in|out] graph Save weight information after parsing
  309. * @return SUCCESS parse successfully
  310. * @return FAILED parse failed
  311. */
  312. static Status ConvertNetParameter(const NetParameter &param, ge::ComputeGraphPtr &graph);
  313. Status Parse(const char *file, ge::ComputeGraphPtr &graph);
  314. Status ParseWeightByFusionProto(const char *model_path, const string &custom_proto_path,
  315. const string &custom_proto_name, ge::ComputeGraphPtr &graph);
  316. Status ParseLayerParameter(const google::protobuf::Descriptor *layer_descriptor,
  317. const google::protobuf::Message *message,
  318. ge::ComputeGraphPtr &graph);
  319. Status ConvertLayerParameter(const google::protobuf::Message *layer_message,
  320. ge::ComputeGraphPtr &graph);
  321. Status CheckLayersSize(const google::protobuf::Message *message);
  322. Status ConvertLayerProto(const google::protobuf::Message *message,
  323. google::protobuf::Message *layer);
  324. Status ParseLayerField(const google::protobuf::Reflection *reflection,
  325. const google::protobuf::Message *message,
  326. const google::protobuf::FieldDescriptor *field,
  327. google::protobuf::Message *layer);
  328. Status ConvertBlobsProto(const google::protobuf::Message *message,
  329. google::protobuf::Message *blobs);
  330. Status ConvertBlobShapeProto(const google::protobuf::Message *message,
  331. google::protobuf::Message *dest_message);
  332. Status ConvertInnerProdcutProto(const google::protobuf::Message *message,
  333. google::protobuf::Message *dest_message);
  334. Status ConvertConvParamProto(const google::protobuf::Message *message,
  335. google::protobuf::Message *dest_message);
  336. /**
  337. * @ingroup domi_omg
  338. * @brief Layer types to be ignored in weight resolution
  339. */
  340. static const set<string> skiped_layer_type_;
  341. std::map<std::string, int32_t> layer_name_record_map_;
  342. };
  }  // namespace ge
  344. #endif // PARSER_CAFFE_CAFFE_PARSER_H_