You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensorflow_parser.h 27 kB

5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef PARSER_TENSORFLOW_TENSORFLOW_PARSER_H_
  17. #define PARSER_TENSORFLOW_TENSORFLOW_PARSER_H_
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "graph/compute_graph.h"
#include "graph/ge_attr_value.h"
#include "graph/ge_tensor.h"
#include "graph/op_desc.h"
#include "graph/operator.h"
#include "graph/range_vistor.h"
#include "graph/utils/attr_utils.h"
#include "graph/utils/tensor_utils.h"
#include "omg/parser/model_parser.h"
#include "omg/parser/op_parser.h"
#include "omg/parser/weights_parser.h"
#include "common/pre_checker.h"
#include "parser/tensorflow/tensorflow_fusion_op_parser.h"
#include "parser/tensorflow/tensorflow_util.h"
#include "proto/om.pb.h"
#include "proto/tensorflow/graph.pb.h"
#include "proto/tensorflow/node_def.pb.h"
#include "proto/tensorflow/graph_library.pb.h"
#include "register/scope/scope_graph_impl.h"
#include "external/register/scope/scope_fusion_pass_register.h"
#include "scope/scope_pass_manager.h"
#include "common/parser_utils.h"
  47. namespace ge {
  48. using std::string;
  49. using std::vector;
  50. using std::set;
  51. using std::map;
  52. using std::unordered_map;
  53. using std::mutex;
  54. using std::shared_ptr;
  /**
   * @brief Transpose direction of a TensorFlow Transpose node.
   *
   * Produced by TensorFlowModelParser::GetFormatTranspose() and passed along in
   * TensorFlowModelParser::GetNodeFormat().
   */
  enum TfTranspose {
    TO_NCHW,       // transpose into NCHW layout
    TO_NHWC,       // transpose into NHWC layout
    NO_TRANSPOSE   // no layout transpose
  };
  /**
   * @brief Input/output edge context of one operator node, keyed by node name.
   *
   * Both maps have the same shape: node name -> list of int32_t index pairs.
   * NOTE(review): the meaning of the two values in each pair is defined by the
   * parser implementation and is not visible here (likely source/destination
   * port indices) — confirm in tensorflow_parser.cc.
   */
  struct OpNodeContext {
    // <node name, list of index pairs> describing this node's inputs
    std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> input_map;
    // <node name, list of index pairs> describing this node's outputs
    std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> output_map;
  };
  // Forward declaration only; the full definition is not in this header.
  // Within this header it is referenced only by TensorFlowModelParser::OptimizeTranspose().
  struct DelTransposeInfo;
  63. class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser {
  64. public:
  65. TensorFlowModelParser() {}
  66. ~TensorFlowModelParser() override {}
  67. /**
  68. * @ingroup domi_omg
  69. * @brief Parse the relevant data from the model file and save it to graph
  70. * @param [in] file Path of the model file
  71. * @param [in|out] graph save model information after parsing
  72. * @return SUCCESS parse successfully
  73. * @return FAILED parse failed
  74. */
  75. Status Parse(const char *file, ge::Graph &graph) override;
  76. /**
  77. * @ingroup domi_omg
  78. * @brief Parse the relevant data from memory and save it to graph
  79. * @param [in] memory buffer of model file
  80. * @param [in] buffer size
  81. * @param [in|out] graph graph for saving model information
  82. * @return SUCCESS parse successfully
  83. * @return FAILED parse failed
  84. */
  85. Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override;
  86. Status ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) override {
  87. (void)data;
  88. (void)size;
  89. (void)graph;
  90. return domi::SUCCESS;
  91. }
  92. /**
  93. * @ingroup domi_omg
  94. * @brief Convert model files to JSON format
  95. * @param [in] model_file Model file path to be converted
  96. * @param [out] json_file Converted JSON file path
  97. * @return SUCCESS parse successfully
  98. * @return others parse failed
  99. */
  100. Status ToJson(const char *model_file, const char *json_file) override;
  101. /**
  102. * @ingroup domi_omg
  103. * @brief Parse the relevant data from the model file and save it to graph
  104. * @param [in] graph_def input tensorflow model
  105. * @param [in|out] graph save model informati:on after parsing
  106. * @return SUCCESS parse successfully
  107. * @return FAILED parse failed
  108. */
  109. Status ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) override;
  110. Status ParseProtoWithSubgraph(const google::protobuf::Message *root_proto,
  111. domi::GetGraphCallback callback,
  112. ge::ComputeGraphPtr &root_graph) override;
  113. /*
  114. * @ingroup domi_omg
  115. * @brief Mapping TF's datatype to GE's datatype
  116. * @param [in] type, datatype types of operators in TF networks
  117. * @return ge::DataType
  118. */
  119. ge::DataType ConvertToGeDataType(const uint32_t type) override;
  120. Status ParseAllGraph(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) override ;
  121. /**
  122. * @ingroup domi_omg
  123. * @brief Analyze network model data
  124. * @param [in] proto serialized network model
  125. * @param [in|out] graph Save the network information after analysis
  126. * @return SUCCESS
  127. * @return Others failed
  128. */
  129. Status ParseProto(const std::string &serialized_proto, ge::ComputeGraphPtr &graph) override;
  130. /**
  131. * @ingroup domi_omg
  132. * @brief Analyze callback model data in subgraph
  133. * @param [in] proto serialized network model
  134. * @param [in] callback callback of subgraph
  135. * @param [in|out] graph Save the network information after analysis
  136. * @return SUCCESS
  137. * @return Others failed
  138. */
  139. Status ParseProtoWithSubgraph(const std::string &root_proto, domi::GetGraphCallbackV2 callback,
  140. ge::ComputeGraphPtr &root_graph) override;
  141. bool HasError() override {
  142. return PreChecker::Instance().HasError();
  143. }
  144. Status Save(const string &file) override {
  145. return PreChecker::Instance().Save(file);
  146. }
  147. void Clear() override {
  148. PreChecker::Instance().Clear();
  149. }
  150. private:
  151. Status Parse(const char *model_path, ge::ComputeGraphPtr &root_graph);
  152. /**
  153. * @ingroup domi_omg
  154. * @brief Add node information to graph
  155. * @param [in|out] op_node_name_list
  156. * @param [in|out] graph save model information after parsing
  157. * @return SUCCESS add successfully
  158. * @return FAILED add failed
  159. */
  160. Status AddFmkNode(ge::ComputeGraphPtr &graph, shared_ptr<ge::ScopeGraph> &scope_graph,
  161. vector<string> &op_node_name_list, bool is_dataset_init = false);
  162. Status AddNodeToGraphAndMarkFormat(ge::ComputeGraphPtr &graph, const vector<string> &op_node_name_list);
  163. /**
  164. * @ingroup domi_omg
  165. * @brief Add node def into node map
  166. * @param NodeDef*
  167. * @return SUCCESS add successfully
  168. * @return FAILED add failed
  169. */
  170. Status AddFmkNodeDefToMap(const domi::tensorflow::NodeDef *node_def,
  171. vector<string> &op_node_name_list);
  172. /**
  173. * @ingroup domi_omg
  174. * @brief Add node information to graph
  175. * @param [in] layer layer infomation
  176. * @param [in|out] graph save model information after parsing
  177. * @return SUCCESS add successfully
  178. * @return FAILED add failed
  179. */
  180. Status AddNode(const domi::tensorflow::NodeDef *node_def,
  181. ge::ComputeGraphPtr &graph,
  182. shared_ptr<ge::ScopeGraph> &scope_graph);
  183. /**
  184. * @ingroup domi_omg
  185. * @brief Add edge information to graph
  186. * @param [in|out] graph save model information after parsing
  187. * @return SUCCESS add successfully
  188. * @return FAILED add failed
  189. */
  190. Status AddEdges(ge::ComputeGraphPtr &graph);
  191. /**
  192. * @ingroup domi_omg
  193. * @brief get op context from the parsed graph
  194. */
  195. Status GetOpNodesContextFromGraph(const domi::tensorflow::GraphDef &graph_def);
  196. /**
  197. * @ingroup domi_omg
  198. * @brief get input,include opNode and constNode
  199. * @param [in] op_node_name op name
  200. * @param [out] input_map input node and index
  201. * @return SUCCESS get successfully
  202. * @return FAILED get failed
  203. */
  204. Status GetOpNodeInputMap(const string &op_node_name,
  205. map<string, std::vector<std::pair<int32_t, int32_t>>> &input_map);
  206. /**
  207. * @ingroup domi_omg
  208. * @brief get output of node
  209. * @param [in] graph_def graph
  210. * @return SUCCESS get successfully
  211. * @return FAILED get failed
  212. */
  213. Status GetOpNodeOutputMap(const domi::tensorflow::GraphDef &graph_def);
  214. /**
  215. * @ingroup domi_omg
  216. * @brief whether const OP need to update context
  217. * @param const op name
  218. * @return true or false
  219. */
  220. bool ConstOpNeedUpdate(const string &op_name);
  221. static Status ExcuteScopeFusionPasses(domi::tensorflow::GraphDef *const graph_def,
  222. shared_ptr<ge::ScopeGraph> &scope_graph);
  223. /**
  224. * @ingroup domi_omg
  225. * @brief Run the scope fusion optimizer in list scope_passes_list
  226. * @param [in] scope_passes_list optimizer list
  227. * @param [in/out] pass_manager an object to manager the optimizers
  228. * @param [in/out] scope_graph Save the result of scope fusion
  229. * @return SUCCESS Run successfully
  230. * @return others Run failed
  231. */
  232. static Status RunScopeFusionPass(const vector<string> &scope_passes_list,
  233. ScopePassManager &pass_manager,
  234. shared_ptr<ge::ScopeGraph> &scope_graph);
  235. /**
  236. * @ingroup domi_omg
  237. * @brief Check whether the nodedef parsed from pb is a fusion operator, put NodeDef into fusion_op_nodedef_map_
  238. * @param [in] graph_def Parsed tensorflow:: graphdef object
  239. * @return maybe a fusion operator
  240. */
  241. bool MaybeFusionOp(shared_ptr<ge::ScopeGraph> &scope_graph, const domi::tensorflow::NodeDef *node_def);
  242. /**
  243. * @Confirm whether it is a child operator of the fusion operator
  244. */
  245. bool IsFusionOpChild(const string &node_name, ge::ScopeFusionOpInfo *info);
  246. /**
  247. * @brief Inner child operators of fusion operators
  248. */
  249. static bool FusionOpChildIgnore(const shared_ptr<ge::ScopeGraph> &scope_graph, const ge::ScopeFusionOpInfo &info);
  250. // Is it a fusion operator
  251. static bool IsFusionOp(const shared_ptr<ge::ScopeGraph> &scope_graph, const domi::tensorflow::NodeDef *node_def);
  252. /**
  253. * @brief get inPut index of the fusion operator
  254. * @param [in] info Child node description of fusion operator
  255. * @param [in] old_index Child node original index
  256. * @return old_index as input index of the fusion operator
  257. * @return return code
  258. */
  259. static Status GetInPutIndex(shared_ptr<ge::ScopeGraph> &scope_graph,
  260. const ge::ScopeFusionOpInfo &info,
  261. const int32_t old_index,
  262. int32_t &new_index);
  263. /**
  264. * @brief get output index of the fusion operator
  265. * @param [in] info Child node description of fusion operator
  266. * @param [in] old_index Child node original index
  267. * @return old_index as output index of the fusion operator
  268. * @return return code
  269. */
  270. static Status GetOutPutIndex(shared_ptr<ge::ScopeGraph> &scope_graph,
  271. const ge::ScopeFusionOpInfo &info,
  272. const int32_t old_index,
  273. int32_t &new_index);
  274. /**
  275. * @ingroup domi_omg
  276. * @brief Update input-output relationships of all operators
  277. * @param graph_def和op_node_name_list
  278. * @return SUCCESS
  279. * @return FAILED
  280. */
  281. Status UpdateAllNodeOpContext(shared_ptr<ge::ScopeGraph> &scope_graph, vector<string> &op_node_name_list);
  282. /**
  283. * @ingroup domi_omg
  284. * @brief Updating the input-output relationship of fusion operators
  285. * @param info Description of fusion operator
  286. * @param fusion_op_node_context Input-output relationship of fusion operator
  287. * @param normal_op_node_context Input-output relationship of normal operators
  288. * @return SUCCESS
  289. * @return FAILED
  290. */
  291. Status UpdateFusionOpContext(shared_ptr<ge::ScopeGraph> &scope_graph, const ge::ScopeFusionOpInfo &info,
  292. OpNodeContext &fusion_op_node_context, OpNodeContext &normal_op_node_context);
  293. /**
  294. * @ingroup domi_omg
  295. * @brief Updating the input-output relationship of normal operators
  296. * @param normal_op_node_context Input-output relationship of normal operators
  297. * @return SUCCESS
  298. * @return FAILED
  299. */
  300. Status UpdateNormalOpContext(shared_ptr<ge::ScopeGraph> &scope_graph, const string &op_node_name,
  301. OpNodeContext &normal_op_node_context);
  302. Status EraseNormalOpOutputIfChild(shared_ptr<ge::ScopeGraph> &scope_graph, const string &op_node_name,
  303. OpNodeContext &normal_op_node_context);
  304. /**
  305. * @ingroup domi_omg
  306. * @brief Normalized I / O relationship: de duplication and de outliers
  307. */
  308. Status NormalizeAllNodeOpContext();
  309. /**
  310. * @ingroup domi_omg
  311. * @brief Normalized I / O relationship: according to context map, de duplicate and de outliers
  312. */
  313. Status NormalizeInputOrOutputMap(const string &node_name,
  314. std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> &context_map);
  315. /**
  316. * @ingroup domi_omg
  317. * @brief delete fusionNodeDef
  318. */
  319. void DeleteFuisonNodeDef();
  320. /**
  321. * @ingroup domi_omg
  322. * @brief Save the control attribute to edges control map
  323. */
  324. void SaveEdgesControlInfo(const string &node_name, const bool control);
  325. /**
  326. * @ingroup domi_omg
  327. * @brief Update the control property to edges control map
  328. */
  329. void UpdateEdgesControlInfo(const ge::ScopeFusionOpInfo &info);
  330. /**
  331. * @ingroup domi_omg
  332. * @brief get contral information
  333. */
  334. bool GetEdgesControlInfo(const string &node_name, const int32_t index) const;
  335. /**
  336. * @ingroup domi_omg
  337. * @brief Check the validity of input_name
  338. * @param input_node_name,Consider the input: n scenario
  339. * @param index ,return index,"input":return 0,"input:n":return n
  340. * @param index ,control index, input: "^cond/switch_t"
  341. * @return SUCCESS
  342. * @return FAILED
  343. */
  344. static Status CheckInputNodeName(const string &input_node_name, string *node_name, int32_t *index, bool *control);
  345. /**
  346. * @ingroup domi_omg
  347. * @brief ge stoi
  348. * @param input_node_name,Consider the input: n scenario
  349. * @param index_str ,stoi param
  350. * @param index ,return index,"input":return 0,"input:n":return n
  351. * @return SUCCESS
  352. * @return FAILED
  353. */
  354. static Status GeStoi(const string &input_node_name, const string &index_str, int32_t *index);
  355. /**
  356. * @ingroup domi_omg
  357. * @brief Clearing the error information of non key operators in fusion operators
  358. */
  359. Status ClearFusionOpError(const vector<string> &op_node_name_list);
  360. /**
  361. * @ingroup domi_omg
  362. * @brief Delete the connection relationship of the identity operator connecting the Arg node in graphdef
  363. */
  364. Status GraphDefOptimize(domi::tensorflow::GraphDef *graph_def);
  365. Status GraphDefOptimizeSnapShot(domi::tensorflow::GraphDef *graph_def, map<string, NodeDef *> &nodedef_map,
  366. const vector<NodeDef *> &nodedef_to_optimize);
  367. Status GraphDefOptimizeDestroyTemporaryVariable(domi::tensorflow::GraphDef *graph_def,
  368. domi::tensorflow::NodeDef *const nodeCurrent) const;
  369. Status OptimizeSnapShot(domi::tensorflow::NodeDef *curr_mode_def, map<string, NodeDef *> &nodedef_map,
  370. const std::pair<string, int> &input_data, const std::vector<string> &control_list);
  371. static Status SetDestNodeName(const domi::tensorflow::NodeDef *const node_current,
  372. domi::tensorflow::NodeDef *const node_dest, const int32_t input_idx,
  373. const bool is_control, bool &clear_input_flag);
  374. void OptimizeDestroyTemporaryVariable(domi::tensorflow::GraphDef *const graph_def,
  375. domi::tensorflow::NodeDef *const nodeCurrent, bool &clearInputFlag) const;
  376. static void OptimizeTranspose(std::map<std::string, DelTransposeInfo> &transposeInfo);
  377. static void SoftmaxAddAttr(domi::tensorflow::GraphDef *const graph_def);
  378. /**
  379. * @ingroup domi_omg
  380. * @brief Delete isolated nodes in graph
  381. */
  382. static Status RemoveIsolateNode(const ge::ComputeGraphPtr &graph);
  383. /**
  384. * @ingroup domi_omg
  385. * @brief Infer format for input ops.
  386. */
  387. domiTensorFormat_t InferInputFormats();
  388. /**
  389. * @ingroup domi_omg
  390. * @brief Get node format.
  391. */
  392. Status GetNodeFormat(const NodeDef *node, TfTranspose pred_transpose, domiTensorFormat_t &format,
  393. set<const NodeDef *> &visited_node);
  394. /**
  395. * @ingroup domi_omg
  396. * @brief Get format transpose.
  397. */
  398. Status GetFormatTranspose(const NodeDef *transpose_node, TfTranspose &transpose_direc) const;
  399. static Status TrimGraph(const domi::tensorflow::GraphDef &input_graph_def,
  400. domi::tensorflow::GraphDef *output_graph_def);
  401. static Status TrimGraphByInput(const domi::tensorflow::GraphDef &input_graph_def,
  402. domi::tensorflow::GraphDef *const output_graph_def);
  403. static Status TrimGraphByOutput(const domi::tensorflow::GraphDef &input_graph_def,
  404. domi::tensorflow::GraphDef *const output_graph_def);
  405. static string NodeNameFromInput(const string &input_name);
  406. Status AddTensorDescToOpDesc(ge::OpDescPtr &op_desc, const domi::tensorflow::NodeDef *node) const;
  407. Status CheckoutInputNum(ge::OpDescPtr &op_desc, const domi::tensorflow::NodeDef *node) const;
  408. static void UpdateInputTensor(ge::OpDescPtr &op_desc, const std::vector<ge::GeTensorDesc> &input_desc,
  409. const size_t input_tensor_num);
  410. static void UpdateOutputTensor(ge::OpDescPtr &op_desc, const std::vector<ge::GeTensorDesc> &output_desc,
  411. size_t output_tensor_num);
  412. Status TransNodeToOpDesc(const domi::tensorflow::NodeDef *node_def, ge::OpDescPtr &op, const string &op_type);
  413. Status UppdateInputMap(shared_ptr<ge::ScopeGraph> &scope_graph, const ge::ScopeFusionOpInfo &info,
  414. OpNodeContext &fusion_op_node_context, OpNodeContext &normal_op_node_context);
  415. Status UppdateOutputMap(shared_ptr<ge::ScopeGraph> &scope_graph, const ge::ScopeFusionOpInfo &info,
  416. OpNodeContext &fusion_op_node_context, OpNodeContext &normal_op_node_context);
  417. void GetInputOutputTensorNum(const ge::OpDescPtr &op_desc, size_t &input_tensor_num,
  418. size_t &output_tensor_num) const;
  419. static Status CheckOpShapeDim(const domi::tensorflow::NodeDef *node_def, const std::set<int> &dims, bool &valid);
  420. Status CheckOpType(const domi::tensorflow::NodeDef *node_def, string &op_type);
  421. /**
  422. * @ingroup domi_omg
  423. * @brief Trans common decorate function to PartitionedCall.
  424. * @param [in] node_def: Node of common function.
  425. * @param [out] op: result of PartitionedCall OpDesc.
  426. * @return 0: SUCCESS / Others: FAILED
  427. */
  428. Status DefunToPartitionedCall(const domi::tensorflow::NodeDef *node_def, ge::OpDescPtr &op) const;
  429. /**
  430. * @ingroup domi_omg
  431. * @brief Calling ParseParams method of fusion operator
  432. * @param op_parser,op parser of the fusion operator
  433. * @return SUCCESS
  434. * @return FAILED
  435. */
  436. Status FusionNodeParseParams(shared_ptr<OpParser> &op_parser,
  437. const domi::tensorflow::NodeDef *node_def, ge::NodePtr &node) const;
  438. /**
  439. * @ingroup domi_omg
  440. * @brief Optimizing const nodes for custom operators
  441. * @param [in] graph_def graph object
  442. * @return true optimize successfully
  443. * @return false optimize failed
  444. *
  445. */
  446. Status OptimizeConstNodes4CustomOp(domi::tensorflow::GraphDef *graph_def) const;
  447. /**
  448. * @ingroup domi_omg
  449. * @brief Delete input from nodedef
  450. * @param [in] node_def Nodedef object
  451. * @param [in] remove_index_set Index collection of input nodes to be deleted
  452. * @return true remove successfully
  453. * @return false remove failed
  454. *
  455. */
  456. Status RemoveInputs(domi::tensorflow::GraphDef *graph_def,
  457. domi::tensorflow::NodeDef *node_def,
  458. const set<uint32_t> &remove_index_set,
  459. const map<string, NodeDef *> &all_node_map) const;
  460. Status AddControlEdgeAfterRemoveInputs(domi::tensorflow::GraphDef *graph_def,
  461. domi::tensorflow::NodeDef *node_def,
  462. const map<string, NodeDef *> &all_node_map,
  463. const vector<string> &removed_inputs_vec) const;
  464. void RemoveInputAttr(domi::tensorflow::NodeDef *node_def, const map<string, vector<int>> &remove_inputs_map) const;
  465. /**
  466. * @ingroup domi_omg
  467. * @brief Parse the parameters in nodedef and construct Ge node.
  468. * This function is a thread function,Parallel parse nodedef in tensorflow graph
  469. * The member variables that need to be modified in this function should be locked
  470. * @param [in] parser TensorFlowModelParser
  471. * @param [in] graph ge graph
  472. * @param [in] graphMutex ge graph lock
  473. * @param [in] scope_graph
  474. * @param [in] node_def Nodedef
  475. * @return SUCCESS
  476. * @return FAILED
  477. *
  478. */
  479. static Status ParseNodeDef(TensorFlowModelParser *parser, ge::ComputeGraphPtr &graph, std::mutex *graphMutex,
  480. shared_ptr<ge::ScopeGraph> &scope_graph, const domi::tensorflow::NodeDef *node_def,
  481. error_message::Context error_context);
  482. /**
  483. * @ingroup domi_omg
  484. * @brief adape op type
  485. * @param [in] node_def Nodedef
  486. * @param [in] isDatasetInit
  487. * @return SUCCESS adapt successfully
  488. * @return others adapt failed
  489. *
  490. */
  491. Status AdaptOpType(const domi::tensorflow::NodeDef *node_def, bool isDatasetInit);
  492. Status GetTensorflowGraphInOutMap(domi::tensorflow::GraphDef *graph_def);
  493. Status RemoveIsolateNode(domi::tensorflow::GraphDef *graph_def);
  494. static Status RecordFusionResult(const std::shared_ptr<ge::ScopeGraph> &scope_graph,
  495. const domi::tensorflow::NodeDef *node,
  496. const ge::OpDescPtr &op_desc);
  497. static Status GetFunctionProto(const string &file, domi::tensorflow::GraphDefLibrary &graph_def_library);
  498. Status SetOriginNodeContext(const NodeDef *node_def, OpNodeContext &op_node_context,
  499. const std::vector<std::pair<std::string, int32_t>> &inputs,
  500. const std::vector<std::pair<std::string, int32_t>> &outputs);
  501. static void GetFusionInputInfo(const string &fusion_op_name, OpNodeContext &fusion_context,
  502. std::map<string, std::pair<std::string, std::pair<int32_t, int32_t>>> &remap_data_input,
  503. std::map<string, std::vector<string>> &remap_ctrl_input,
  504. std::set<string> &fusion_input_nodes);
  505. static void GetFusionOutputInfo(const string &fusion_op_name, OpNodeContext &fusion_context,
  506. std::map<string, std::vector<std::pair<std::string, std::pair<int32_t, int32_t>>>> &remap_data_output,
  507. std::map<string, std::vector<string>> &remap_ctrl_output,
  508. std::set<string> &fusion_output_nodes);
  509. void UpdateInnerInputMap(const string &fusion_op_name, OpNodeContext &fusion_context,
  510. const std::vector<std::string> &inner_nodes_name,
  511. std::set<string> &fusion_input_nodes);
  512. void UpdateInnerOutputMap(const string &fusion_op_name, OpNodeContext &fusion_context,
  513. const std::vector<std::string> &inner_nodes_name,
  514. std::set<string> &fusion_output_nodes);
  515. Status UpdateInnerNodeContext(const string &fusion_op_name, const std::vector<std::string> &inner_nodes_name);
  516. Status AddFusionInnerNodeDef(shared_ptr<ge::ScopeGraph> &scope_graph,
  517. const string &fusion_op_name,
  518. vector<string> &node_name_list);
  519. Status AddFusionNodeDef(shared_ptr<ge::ScopeGraph> &scope_graph, vector<string> &node_name_list);
  520. static Status AddScopeInnerNode(TensorFlowModelParser *parser, ge::ComputeGraphPtr &graph,
  521. std::mutex *const graph_mutex, const domi::tensorflow::NodeDef *node_def);
  522. static void DumpNodeContext(const string &node_name, const OpNodeContext &ctx, const string &phase);
  523. void DumpAllNodeContext(const string &phase) const;
  524. static Status ParseOpParams(const domi::tensorflow::NodeDef *node_def, ge::OpDescPtr &op,
  525. const shared_ptr<OpParser> &op_parser);
  526. static Status CheckAndUpdateInputDesc(const ge::ComputeGraphPtr &compute_graph);
  527. static Status UpdateOutputsInfo(const ParserUtils::OutputMapping &final_output_nodes);
  528. static Status AddExternalGraph(const ComputeGraphPtr &root_graph);
  529. /**
  530. * save <node_name, node_def>
  531. */
  532. unordered_map<string, const NodeDef *> nodedef_map_;
  533. /**
  534. * context, Input output relationship
  535. */
  536. unordered_map<string, OpNodeContext> op_node_context_map_;
  537. /**
  538. * Name of node of OP type, corresponding to node of DaVinci
  539. */
  540. std::unordered_map<std::string, ge::NodePtr> node_map_;
  541. /**
  542. * node_map_ Multithreaded write operation is involved, requiring lock protection
  543. */
  544. std::mutex nodeMapMutex_;
  545. /**
  546. * save <node_name, nodeDefList>
  547. */
  548. map<string, vector<const NodeDef *>> fusion_op_nodedef_map_;
  549. // Policy types of fusion operators,true:scope_pass match,false:prefix match
  550. map<string, bool> fusion_op_policy_;
  551. // The names of all children operators and the description of fusion operators
  552. unordered_map<string, ge::ScopeFusionOpInfo> fusion_op_children_;
  553. /**
  554. * save <node_name, {fusionOpName,description}>
  555. */
  556. map<string, vector<string>> fusion_op_type_map_;
  557. /**
  558. * save nodedef of the fusion operator
  559. */
  560. vector<domi::tensorflow::NodeDef *> fusion_nodedef_list;
  561. /**
  562. * control edge,{Key=NodeName,Value=index}
  563. */
  564. map<string, vector<int32_t>> edges_control_map;
  565. unordered_map<string, const domi::tensorflow::NodeDef *> framework_ops_;
  566. /**
  567. * save <node_name, op_type>
  568. */
  569. map<string, string> adaptedOpTypeMap_;
  570. // { node_name <{input_node_name}, {output_node_name}> }
  571. map<string, std::pair<set<string>, set<string>>> node_inputs_outputs_map_;
  572. unordered_map<string, const ge::Operator *> scope_inner_node_map_;
  573. };
  574. /**
  575. * @ingroup domi_omg
  576. * @brief Tensorflow weight parse
  577. */
  578. class PARSER_FUNC_VISIBILITY TensorFlowWeightsParser : public domi::WeightsParser {
  579. public:
  580. /**
  581. * @ingroup domi_omg
  582. * @brief Parse weight data from file and save to graph
  583. * @param [in] file Path of weight file after training
  584. * @param [in|out] graph Save weight information after analysis
  585. * @return SUCCESS parse successfully
  586. * @return PARAM_INVALID param invalid
  587. * @return PARSE_WEIGHTS_FAILED parse failed
  588. */
  589. Status Parse(const char *file, ge::Graph &graph) override;
  590. Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override;
  591. bool HasError() override {
  592. return PreChecker::Instance().HasError();
  593. }
  594. Status Save(const string &file) override {
  595. return PreChecker::Instance().Save(file);
  596. }
  597. void Clear() override {
  598. PreChecker::Instance().Clear();
  599. }
  600. };
  } // namespace ge
  602. #endif // PARSER_TENSORFLOW_TENSORFLOW_PARSER_H_