You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

tensorflow_parser.h 28 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef PARSER_TENSORFLOW_TENSORFLOW_PARSER_H_
  17. #define PARSER_TENSORFLOW_TENSORFLOW_PARSER_H_
#include <cstdint>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "graph/compute_graph.h"
#include "graph/ge_attr_value.h"
#include "graph/ge_tensor.h"
#include "graph/op_desc.h"
#include "graph/operator.h"
#include "graph/range_vistor.h"
#include "graph/utils/attr_utils.h"
#include "graph/utils/tensor_utils.h"
#include "omg/parser/model_parser.h"
#include "omg/parser/op_parser.h"
#include "omg/parser/weights_parser.h"
#include "parser/tensorflow/tensorflow_fusion_op_parser.h"
#include "parser/tensorflow/tensorflow_fusionop_util.h"
#include "parser/tensorflow/tensorflow_util.h"
#include "proto/om.pb.h"
#include "proto/tensorflow/graph.pb.h"
#include "proto/tensorflow/node_def.pb.h"
#include "proto/tensorflow/graph_library.pb.h"
#include "external/register/scope/scope_fusion_pass_register.h"
#include "scope/scope_pass_manager.h"
#include "common/parser_utils.h"
  46. using ge::ScopePassManager;
  47. using domi::tensorflow::GraphDef;
  48. using domi::tensorflow::DT_HALF;
  49. using domi::tensorflow::NodeDef;
  50. using domi::tensorflow::GraphDef;
  51. using domi::tensorflow::AttrValue;
  52. using domi::tensorflow::DataType;
  53. using ge::OpParser;
  54. namespace ge {
  55. using std::string;
  56. using std::vector;
  57. using std::set;
  58. using std::map;
  59. using std::unordered_map;
  60. using std::mutex;
  61. using std::shared_ptr;
  /**
   * Transpose direction used by the format-inference helpers
   * (GetFormatTranspose / GetNodeFormat).
   * NOTE(review): the exact layout semantics are inferred from the names —
   * confirm against the implementation.
   */
  enum TfTranspose {
    TO_NCHW = 0,       // transpose towards NCHW layout (per name)
    TO_NHWC = 1,       // transpose towards NHWC layout (per name)
    NO_TRANSPOSE = 2,  // no layout transpose
  };
  /**
   * Input/output relationships of a single operator node.
   *
   * Both maps are keyed by the name of the peer node. Each value is a list of
   * index pairs, because one peer may be connected through several edges.
   * NOTE(review): which index is pair.first vs. pair.second (e.g. source vs.
   * destination slot) is not visible in this header — confirm in the .cpp.
   */
  struct OpNodeContext {
    // <input node name, list of index pairs>
    std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> input_map;
    // <output node name, list of index pairs>. Same shape as input_map: a peer
    // may appear with several index pairs, not a single index.
    std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> output_map;
  };
  69. struct DelTransposeInfo;
  70. class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser {
  71. public:
  72. TensorFlowModelParser() {}
  73. virtual ~TensorFlowModelParser() {}
  74. /**
  75. * @ingroup domi_omg
  76. * @brief Parse the relevant data from the model file and save it to graph
  77. * @param [in] file Path of the model file
  78. * @param [in|out] graph save model information after parsing
  79. * @return SUCCESS parse successfully
  80. * @return FAILED parse failed
  81. */
  82. Status Parse(const char *file, ge::Graph &graph) override;
  83. /**
  84. * @ingroup domi_omg
  85. * @brief Parse the relevant data from memory and save it to graph
  86. * @param [in] memory buffer of model file
  87. * @param [in] buffer size
  88. * @param [in|out] graph graph for saving model information
  89. * @return SUCCESS parse successfully
  90. * @return FAILED parse failed
  91. */
  92. Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override;
  93. Status ParseFromMemory(const char *data, uint32_t size, ge::Graph &graph) override {
  94. return domi::SUCCESS;
  95. }
  96. /**
  97. * @ingroup domi_omg
  98. * @brief Convert model files to JSON format
  99. * @param [in] model_file Model file path to be converted
  100. * @param [out] json_file Converted JSON file path
  101. * @return SUCCESS parse successfully
  102. * @return others parse failed
  103. */
  104. Status ToJson(const char *model_file, const char *json_file) override;
  105. /**
  106. * @ingroup domi_omg
  107. * @brief Parse the relevant data from the model file and save it to graph
  108. * @param [in] graph_def input tensorflow model
  109. * @param [in|out] graph save model informati:on after parsing
  110. * @return SUCCESS parse successfully
  111. * @return FAILED parse failed
  112. */
  113. Status ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) override;
  114. Status ParseProtoWithSubgraph(const google::protobuf::Message *root_proto,
  115. domi::GetGraphCallback callback,
  116. ge::ComputeGraphPtr &graph) override;
  117. /*
  118. * @ingroup domi_omg
  119. * @brief Mapping TF's datatype to GE's datatype
  120. * @param [in] type, datatype types of operators in TF networks
  121. * @return ge::DataType
  122. */
  123. ge::DataType ConvertToGeDataType(const uint32_t type) override;
  124. Status ParseAllGraph(const google::protobuf::Message *root_proto, ge::ComputeGraphPtr &root_graph) override ;
  125. /**
  126. * @ingroup domi_omg
  127. * @brief Analyze network model data
  128. * @param [in] proto serialized network model
  129. * @param [in|out] graph Save the network information after analysis
  130. * @return SUCCESS
  131. * @return Others failed
  132. */
  133. Status ParseProto(const std::string &serialized_proto, ge::ComputeGraphPtr &graph) override;
  134. /**
  135. * @ingroup domi_omg
  136. * @brief Analyze callback model data in subgraph
  137. * @param [in] proto serialized network model
  138. * @param [in] callback callback of subgraph
  139. * @param [in|out] graph Save the network information after analysis
  140. * @return SUCCESS
  141. * @return Others failed
  142. */
  143. Status ParseProtoWithSubgraph(const std::string &serialized_proto, domi::GetGraphCallbackV2 callback,
  144. ge::ComputeGraphPtr &graph) override;
  145. private:
  146. Status Parse(const char *file, ge::ComputeGraphPtr &graph);
  147. /**
  148. * @ingroup domi_omg
  149. * @brief Add node information to graph
  150. * @param [in|out] op_node_name_list
  151. * @param [in|out] graph save model information after parsing
  152. * @return SUCCESS add successfully
  153. * @return FAILED add failed
  154. */
  155. Status AddFmkNode(ge::ComputeGraphPtr &graph, shared_ptr<ge::ScopeGraph> &scope_graph,
  156. vector<string> &op_node_name_list, bool is_dataset_init = false);
  157. Status AddNodeToGraphAndMarkFormat(ge::ComputeGraphPtr &graph, const vector<string> &op_node_name_list);
  158. /**
  159. * @ingroup domi_omg
  160. * @brief Add node def into node map
  161. * @param NodeDef*
  162. * @return SUCCESS add successfully
  163. * @return FAILED add failed
  164. */
  165. Status AddFmkNodeDefToMap(const domi::tensorflow::GraphDef &graph_def, const domi::tensorflow::NodeDef *node_def,
  166. vector<string> &op_node_name_list);
  167. /**
  168. * @ingroup domi_omg
  169. * @brief Add node information to graph
  170. * @param [in] layer layer infomation
  171. * @param [in|out] graph save model information after parsing
  172. * @return SUCCESS add successfully
  173. * @return FAILED add failed
  174. */
  175. Status AddNode(const domi::tensorflow::NodeDef *node_def,
  176. ge::ComputeGraphPtr &graph,
  177. shared_ptr<ge::ScopeGraph> &scope_graph);
  178. /**
  179. * @ingroup domi_omg
  180. * @brief Add edge information to graph
  181. * @param [in|out] graph save model information after parsing
  182. * @return SUCCESS add successfully
  183. * @return FAILED add failed
  184. */
  185. Status AddEdges(ge::ComputeGraphPtr &graph);
  186. /**
  187. * @ingroup domi_omg
  188. * @brief get op context from the parsed graph
  189. */
  190. Status GetOpNodesContextFromGraph(const domi::tensorflow::GraphDef &graph_def);
  191. /**
  192. * @ingroup domi_omg
  193. * @brief get input,include opNode and constNode
  194. * @param [in] op_node_name op name
  195. * @param [out] input_map input node and index
  196. * @return SUCCESS get successfully
  197. * @return FAILED get failed
  198. */
  199. Status GetOpNodeInputMap(const string &op_node_name,
  200. map<string, std::vector<std::pair<int32_t, int32_t>>> &input_map);
  201. /**
  202. * @ingroup domi_omg
  203. * @brief get output of node
  204. * @param [in] graph_def graph
  205. * @return SUCCESS get successfully
  206. * @return FAILED get failed
  207. */
  208. Status GetOpNodeOutputMap(const domi::tensorflow::GraphDef &graph_def);
  209. /**
  210. * @ingroup domi_omg
  211. * @brief Verifying the validity of graphdef object parsed by pb
  212. * @param [in] graph_def Parsed tensorflow:: graphdef object
  213. * @return SUCCESS check successfully
  214. * @return FAILED check failed
  215. */
  216. Status CheckGraphDefValid(const domi::tensorflow::GraphDef &graph_def);
  217. /**
  218. * @ingroup domi_omg
  219. * @brief whether const OP need to update context
  220. * @param const op name
  221. * @return true or false
  222. */
  223. bool ConstOpNeedUpdate(const string &op_name);
  224. Status ExcuteScopeFusionPasses(domi::tensorflow::GraphDef *graph_def, shared_ptr<ge::ScopeGraph> &scope_graph);
  225. /**
  226. * @ingroup domi_omg
  227. * @brief Run the scope fusion optimizer in list scope_passes_list
  228. * @param [in] scope_passes_list optimizer list
  229. * @param [in/out] pass_manager an object to manager the optimizers
  230. * @param [in/out] scope_graph Save the result of scope fusion
  231. * @return SUCCESS Run successfully
  232. * @return others Run failed
  233. */
  234. Status RunScopeFusionPass(const vector<string> &scope_passes_list,
  235. ScopePassManager &pass_manager,
  236. shared_ptr<ge::ScopeGraph> &scope_graph);
  237. /**
  238. * @ingroup domi_omg
  239. * @brief Check whether the nodedef parsed from pb is a fusion operator, put NodeDef into fusion_op_nodedef_map_
  240. * @param [in] graph_def Parsed tensorflow:: graphdef object
  241. * @return maybe a fusion operator
  242. */
  243. bool MaybeFusionOp(shared_ptr<ge::ScopeGraph> &scope_graph, const domi::tensorflow::NodeDef *node_def);
  244. /**
  245. * @Confirm whether it is a child operator of the fusion operator
  246. */
  247. bool IsFusionOpChild(const string &node_name, ge::ScopeFusionOpInfo *info);
  248. /**
  249. * @brief Inner child operators of fusion operators
  250. */
  251. bool FusionOpChildIgnore(shared_ptr<ge::ScopeGraph> &scope_graph, const ge::ScopeFusionOpInfo &info);
  252. // Is it a fusion operator
  253. bool IsFusionOp(shared_ptr<ge::ScopeGraph> &scope_graph, const domi::tensorflow::NodeDef *node_def);
  254. /**
  255. * @brief get inPut index of the fusion operator
  256. * @param [in] info Child node description of fusion operator
  257. * @param [in] old_index Child node original index
  258. * @return old_index as input index of the fusion operator
  259. * @return return code
  260. */
  261. static Status GetInPutIndex(shared_ptr<ge::ScopeGraph> &scope_graph,
  262. const ge::ScopeFusionOpInfo &info,
  263. const int32_t old_index,
  264. int32_t &new_index);
  265. /**
  266. * @brief get output index of the fusion operator
  267. * @param [in] info Child node description of fusion operator
  268. * @param [in] old_index Child node original index
  269. * @return old_index as output index of the fusion operator
  270. * @return return code
  271. */
  272. static Status GetOutPutIndex(shared_ptr<ge::ScopeGraph> &scope_graph,
  273. const ge::ScopeFusionOpInfo &info,
  274. const int32_t old_index,
  275. int32_t &new_index);
  276. /**
  277. * @ingroup domi_omg
  278. * @brief Check the validity of fusionop,put it into op_node_name_list if Misjudgement
  279. * @param op_node_name_list
  280. * @return SUCCESS check successfully
  281. * @return FAILED check failed
  282. */
  283. Status CheckFusionOpValid();
  284. /**
  285. * @ingroup domi_omg
  286. * @brief Update input-output relationships of all operators
  287. * @param graph_def和op_node_name_list
  288. * @return SUCCESS
  289. * @return FAILED
  290. */
  291. Status UpdateAllNodeOpContext(shared_ptr<ge::ScopeGraph> &scope_graph, const domi::tensorflow::GraphDef &graph_def,
  292. vector<string> &op_node_name_list);
  293. /**
  294. * @ingroup domi_omg
  295. * @brief Updating the input-output relationship of fusion operators
  296. * @param info Description of fusion operator
  297. * @param fusion_op_node_context Input-output relationship of fusion operator
  298. * @param normal_op_node_context Input-output relationship of normal operators
  299. * @return SUCCESS
  300. * @return FAILED
  301. */
  302. Status UpdateFusionOpContext(shared_ptr<ge::ScopeGraph> &scope_graph, const ge::ScopeFusionOpInfo &info,
  303. OpNodeContext &fusion_op_node_context, OpNodeContext &normal_op_node_context);
  304. /**
  305. * @ingroup domi_omg
  306. * @brief Updating the input-output relationship of normal operators
  307. * @param normal_op_node_context Input-output relationship of normal operators
  308. * @return SUCCESS
  309. * @return FAILED
  310. */
  311. Status UpdateNormalOpContext(shared_ptr<ge::ScopeGraph> &scope_graph, const string &op_node_name,
  312. OpNodeContext &normal_op_node_context);
  313. Status EraseNormalOpOutputIfChild(shared_ptr<ge::ScopeGraph> &scope_graph, const string &op_node_name,
  314. OpNodeContext &normal_op_node_context);
  315. /**
  316. * @ingroup domi_omg
  317. * @brief Normalized I / O relationship: de duplication and de outliers
  318. */
  319. Status NormalizeAllNodeOpContext();
  320. /**
  321. * @ingroup domi_omg
  322. * @brief Normalized I / O relationship: according to context map, de duplicate and de outliers
  323. */
  324. Status NormalizeInputOrOutputMap(const string &node_name,
  325. std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> &context_map);
  326. /**
  327. * @ingroup domi_omg
  328. * @brief delete fusionNodeDef
  329. */
  330. void DeleteFuisonNodeDef();
  331. /**
  332. * @ingroup domi_omg
  333. * @brief Save the control attribute to edges control map
  334. */
  335. void SaveEdgesControlInfo(const string &node_name, const bool control);
  336. /**
  337. * @ingroup domi_omg
  338. * @brief Update the control property to edges control map
  339. */
  340. void UpdateEdgesControlInfo(const ge::ScopeFusionOpInfo &info);
  341. /**
  342. * @ingroup domi_omg
  343. * @brief get contral information
  344. */
  345. bool GetEdgesControlInfo(const string &node_name, const int32_t index);
  346. /**
  347. * @ingroup domi_omg
  348. * @brief Check the validity of input_name
  349. * @param input_node_name,Consider the input: n scenario
  350. * @param index ,return index,"input":return 0,"input:n":return n
  351. * @param index ,control index, input: "^cond/switch_t"
  352. * @return SUCCESS
  353. * @return FAILED
  354. */
  355. Status CheckInputNodeName(const string &input_node_name, string *node_name, int32_t *index, bool *control);
  356. /**
  357. * @ingroup domi_omg
  358. * @brief ge stoi
  359. * @param input_node_name,Consider the input: n scenario
  360. * @param index_str ,stoi param
  361. * @param index ,return index,"input":return 0,"input:n":return n
  362. * @return SUCCESS
  363. * @return FAILED
  364. */
  365. Status GeStoi(const string &input_node_name, const string &index_str, int32_t *index);
  366. /**
  367. * @ingroup domi_omg
  368. * @brief Clearing the error information of non key operators in fusion operators
  369. */
  370. Status ClearFusionOpError(const vector<string> &op_node_name_list);
  371. /**
  372. * @ingroup domi_omg
  373. * @brief Delete the connection relationship of the identity operator connecting the Arg node in graphdef
  374. */
  375. Status GraphDefOptimize(domi::tensorflow::GraphDef *graph_def);
  376. /**
  377. * @ingroup domi_omg
  378. * @brief Optimize for Identity/ReadVariableOp operator
  379. * @param [in] graph_def GraphDef to be optimized
  380. * @param [in] nodedef_map Map of all nodes in graph
  381. * @param [in] nodedef_to_optimize vector of NodeDef to be optimized
  382. * @return SUCCESS optimize successfully
  383. * @return others failed
  384. */
  385. Status GraphDefOptimizeIdentity(domi::tensorflow::GraphDef *graph_def, map<string, NodeDef *> &nodedef_map,
  386. const vector<NodeDef *> &nodedef_to_optimize);
  387. /**
  388. * @ingroup domi_omg
  389. * @brief For the identity operator whose output is "_retval", optimize it.
  390. * @param [in] nodedef_map Map of all nodes in graph
  391. * @param [in] curr_node_name Name of node to be optimized
  392. * @param [in] clear_input_flag Flag of whether to clear the input of the current node
  393. * @return SUCCESS optimize successfully
  394. * @return others failed
  395. */
  396. Status OptimizeIdentityByOutput(map<string, NodeDef *> &nodedef_map, const string &curr_node_name,
  397. bool &clear_input_flag);
  398. Status GraphDefOptimizeSnapShot(domi::tensorflow::GraphDef *graph_def, map<string, NodeDef *> &nodedef_map,
  399. const vector<NodeDef *> &nodedef_to_optimize);
  400. Status GraphDefOptimizeDestroyTemporaryVariable(domi::tensorflow::GraphDef *graph_def,
  401. domi::tensorflow::NodeDef *nodeCurrent);
  402. Status OptimizeSnapShot(domi::tensorflow::NodeDef *curr_mode_def, map<string, NodeDef *> &nodedef_map,
  403. const std::pair<string, int> &input_data, const std::vector<string> &control_list);
  404. void OptimizeDestroyTemporaryVariable(domi::tensorflow::GraphDef *graph_def, domi::tensorflow::NodeDef *nodeCurrent,
  405. bool &clearInputFlag);
  406. void OptimizeTranspose(std::map<std::string, DelTransposeInfo> &transposeInfo);
  407. void SoftmaxAddAttr(GraphDef *graph_def);
  408. /**
  409. * @ingroup domi_omg
  410. * @brief Delete isolated nodes in graph
  411. */
  412. Status RemoveIsolateNode(ge::ComputeGraphPtr &graph);
  413. /**
  414. * @ingroup domi_omg
  415. * @brief Infer format for input ops.
  416. */
  417. domiTensorFormat_t InferInputFormats();
  418. /**
  419. * @ingroup domi_omg
  420. * @brief Get node format.
  421. */
  422. Status GetNodeFormat(const NodeDef *node, TfTranspose pred_transpose, domiTensorFormat_t &format,
  423. set<const NodeDef *> &visited_node);
  424. /**
  425. * @ingroup domi_omg
  426. * @brief Get format transpose.
  427. */
  428. Status GetFormatTranspose(const NodeDef *transpose_node, TfTranspose &transpose_direc);
  429. Status TrimGraph(const domi::tensorflow::GraphDef &input_graph_def, domi::tensorflow::GraphDef *output_graph_def);
  430. Status TrimGraphByInput(const domi::tensorflow::GraphDef &input_graph_def,
  431. domi::tensorflow::GraphDef *output_graph_def);
  432. Status TrimGraphByOutput(const domi::tensorflow::GraphDef &input_graph_def,
  433. domi::tensorflow::GraphDef *output_graph_def);
  434. string NodeNameFromInput(const string &input_name);
  435. Status AddTensorDescToOpDesc(ge::OpDescPtr &op_desc, const domi::tensorflow::NodeDef *node);
  436. Status CheckoutInputNum(ge::OpDescPtr &op_desc, const domi::tensorflow::NodeDef *node);
  437. void UpdateInputTensor(ge::OpDescPtr &op_desc, const std::vector<ge::GeTensorDesc> &input_desc,
  438. const size_t input_tensor_num);
  439. void UpdateOutputTensor(ge::OpDescPtr &op_desc, const std::vector<ge::GeTensorDesc> &output_desc,
  440. size_t output_tensor_num);
  441. Status TransNodeToOpDesc(const domi::tensorflow::NodeDef *node_def, ge::OpDescPtr &op, const string &op_type);
  442. Status UppdateInputMap(shared_ptr<ge::ScopeGraph> &scope_graph, const ge::ScopeFusionOpInfo &info,
  443. OpNodeContext &fusion_op_node_context, OpNodeContext &normal_op_node_context);
  444. Status UppdateOutputMap(shared_ptr<ge::ScopeGraph> &scope_graph, const ge::ScopeFusionOpInfo &info,
  445. OpNodeContext &fusion_op_node_context, OpNodeContext &normal_op_node_context);
  446. void GetInputOutputTensorNum (ge::OpDescPtr &op_desc, size_t &input_tensor_num, size_t &output_tensor_num) const;
  447. Status CheckOpShapeDim(const domi::tensorflow::NodeDef *node_def, const std::set<int> &dims, bool &valid);
  448. Status CheckOpType(const domi::tensorflow::NodeDef *node_def, string &op_type);
  449. /**
  450. * @ingroup domi_omg
  451. * @brief Trans common decorate function to PartitionedCall.
  452. * @param [in] node_def: Node of common function.
  453. * @param [out] op: result of PartitionedCall OpDesc.
  454. * @return 0: SUCCESS / Others: FAILED
  455. */
  456. Status DefunToPartitionedCall(const domi::tensorflow::NodeDef *node_def, ge::OpDescPtr &op);
  457. /**
  458. * @ingroup domi_omg
  459. * @brief Calling ParseParams method of fusion operator
  460. * @param op_parser,op parser of the fusion operator
  461. * @return SUCCESS
  462. * @return FAILED
  463. */
  464. Status FusionNodeParseParams(shared_ptr<OpParser> &op_parser,
  465. const domi::tensorflow::NodeDef *node_def, ge::NodePtr &node);
  466. /**
  467. * @ingroup domi_omg
  468. * @brief Optimizing const nodes for custom operators
  469. * @param [in] graph_def graph object
  470. * @return true optimize successfully
  471. * @return false optimize failed
  472. *
  473. */
  474. Status OptimizeConstNodes4CustomOp(domi::tensorflow::GraphDef *graph_def);
  475. /**
  476. * @ingroup domi_omg
  477. * @brief Delete input from nodedef
  478. * @param [in] node_def Nodedef object
  479. * @param [in] remove_index_set Index collection of input nodes to be deleted
  480. * @return true remove successfully
  481. * @return false remove failed
  482. *
  483. */
  484. Status RemoveInputs(domi::tensorflow::GraphDef *graph_def,
  485. domi::tensorflow::NodeDef *node_def,
  486. const set<uint32_t> &remove_index_set,
  487. const map<string, NodeDef *> &all_node_map);
  488. Status AddControlEdgeAfterRemoveInputs(domi::tensorflow::GraphDef *graph_def,
  489. domi::tensorflow::NodeDef *node_def,
  490. const map<string, NodeDef *> &all_node_map,
  491. const vector<string> &removed_inputs_vec);
  492. void RemoveInputAttr(domi::tensorflow::NodeDef *node_def, const map<string, vector<int>> &remove_inputs_map);
  493. /**
  494. * @ingroup domi_omg
  495. * @brief Parse the parameters in nodedef and construct Ge node.
  496. * This function is a thread function,Parallel parse nodedef in tensorflow graph
  497. * The member variables that need to be modified in this function should be locked
  498. * @param [in] parser TensorFlowModelParser
  499. * @param [in] graph ge graph
  500. * @param [in] graphMutex ge graph lock
  501. * @param [in] scope_graph
  502. * @param [in] node_def Nodedef
  503. * @return SUCCESS
  504. * @return FAILED
  505. *
  506. */
  507. static Status ParseNodeDef(TensorFlowModelParser *parser, ge::ComputeGraphPtr &graph, std::mutex *graphMutex,
  508. shared_ptr<ge::ScopeGraph> &scope_graph, const domi::tensorflow::NodeDef *node_def,
  509. error_message::Context error_context);
  510. /**
  511. * @ingroup domi_omg
  512. * @brief adape op type
  513. * @param [in] node_def Nodedef
  514. * @param [in] isDatasetInit
  515. * @return SUCCESS adapt successfully
  516. * @return others adapt failed
  517. *
  518. */
  519. Status AdaptOpType(const domi::tensorflow::NodeDef *node_def, bool isDatasetInit);
  520. Status GetTensorflowGraphInOutMap(domi::tensorflow::GraphDef *graph_def);
  521. Status RemoveIsolateNode(domi::tensorflow::GraphDef *graph_def);
  522. static Status RecordFusionResult(std::shared_ptr<ge::ScopeGraph> &scope_graph,
  523. const domi::tensorflow::NodeDef *node,
  524. ge::OpDescPtr &op_def);
  525. Status GetFunctionProto(const string &file, domi::tensorflow::GraphDefLibrary &graph_def_library);
  526. Status SetOriginNodeContext(NodeDef *node_def, OpNodeContext &op_node_context,
  527. const std::vector<std::pair<std::string, int32_t>> &inputs,
  528. const std::vector<std::pair<std::string, int32_t>> &outputs);
  529. void GetFusionInputInfo(const string &fusion_op_name, OpNodeContext &fusion_context,
  530. std::map<string, std::pair<std::string, std::pair<int32_t, int32_t>>> &remap_data_input,
  531. std::map<string, std::vector<string>> &remap_ctrl_input,
  532. std::set<string> &fusion_input_nodes);
  533. void GetFusionOutputInfo(const string &fusion_op_name, OpNodeContext &fusion_context,
  534. std::map<string, std::vector<std::pair<std::string, std::pair<int32_t, int32_t>>>> &remap_data_output,
  535. std::map<string, std::vector<string>> &remap_ctrl_output,
  536. std::set<string> &fusion_output_nodes);
  537. void UpdateInnerInputMap(const string &fusion_op_name, OpNodeContext &fusion_context,
  538. const std::vector<std::string> &inner_nodes_name,
  539. std::set<string> &fusion_input_nodes);
  540. void UpdateInnerOutputMap(const string &fusion_op_name, OpNodeContext &fusion_context,
  541. const std::vector<std::string> &inner_nodes_name,
  542. std::set<string> &fusion_output_nodes);
  543. Status UpdateInnerNodeContext(const string &fusion_op_name, const std::vector<std::string> &inner_nodes_name);
  544. Status AddFusionInnerNodeDef(shared_ptr<ge::ScopeGraph> &scope_graph,
  545. const string &fusion_op_name,
  546. vector<string> &node_name_list);
  547. Status AddFusionNodeDef(shared_ptr<ge::ScopeGraph> &scope_graph, vector<string> &node_name_list);
  548. static Status AddScopeInnerNode(TensorFlowModelParser *parser, ge::ComputeGraphPtr &graph,
  549. std::mutex *graph_mutex, const domi::tensorflow::NodeDef *node_def);
  550. void DumpNodeContext(const string &node_name, const OpNodeContext &ctx, const string &phase);
  551. void DumpAllNodeContext(const string &phase);
  552. Status ParseOpParams(const domi::tensorflow::NodeDef *node_def, ge::OpDescPtr &op, shared_ptr<OpParser> &op_parser);
  553. Status CheckAndUpdateInputDesc(ge::ComputeGraphPtr &compute_graph);
  554. static Status UpdateOutputsInfo(const ParserUtils::OutputMapping &final_output_nodes);
  555. /**
  556. * save <node_name, node_def>
  557. */
  558. unordered_map<string, const NodeDef *> nodedef_map_;
  559. /**
  560. * context, Input output relationship
  561. */
  562. unordered_map<string, OpNodeContext> op_node_context_map_;
  563. /**
  564. * Name of node of OP type, corresponding to node of DaVinci
  565. */
  566. std::unordered_map<std::string, ge::NodePtr> node_map_;
  567. /**
  568. * node_map_ Multithreaded write operation is involved, requiring lock protection
  569. */
  570. std::mutex nodeMapMutex_;
  571. /**
  572. * save <node_name, nodeDefList>
  573. */
  574. map<string, vector<const NodeDef *>> fusion_op_nodedef_map_;
  575. // Policy types of fusion operators,true:scope_pass match,false:prefix match
  576. map<string, bool> fusion_op_policy_;
  577. // The names of all children operators and the description of fusion operators
  578. unordered_map<string, ge::ScopeFusionOpInfo> fusion_op_children_;
  579. /**
  580. * save <node_name, {fusionOpName,description}>
  581. */
  582. map<string, vector<string>> fusion_op_type_map_;
  583. /**
  584. * save nodedef of the fusion operator
  585. */
  586. vector<domi::tensorflow::NodeDef *> fusion_nodedef_list;
  587. /**
  588. * control edge,{Key=NodeName,Value=index}
  589. */
  590. map<string, vector<int32_t>> edges_control_map;
  591. unordered_map<string, const domi::tensorflow::NodeDef *> framework_ops_;
  592. /**
  593. * save <node_name, op_type>
  594. */
  595. map<string, string> adaptedOpTypeMap_;
  596. // { node_name <{input_node_name}, {output_node_name}> }
  597. map<string, std::pair<set<string>, set<string>>> node_inputs_outputs_map_;
  598. unordered_map<string, const ge::Operator *> scope_inner_node_map_;
  599. };
  600. /**
  601. * @ingroup domi_omg
  602. * @brief Tensorflow weight parse
  603. */
  604. class PARSER_FUNC_VISIBILITY TensorFlowWeightsParser : public domi::WeightsParser {
  605. public:
  606. /**
  607. * @ingroup domi_omg
  608. * @brief Parse weight data from file and save to graph
  609. * @param [in] file Path of weight file after training
  610. * @param [in|out] graph Save weight information after analysis
  611. * @return SUCCESS parse successfully
  612. * @return PARAM_INVALID param invalid
  613. * @return PARSE_WEIGHTS_FAILED parse failed
  614. */
  615. Status Parse(const char *file, ge::Graph &graph) override;
  616. Status ParseFromMemory(const char *data, uint32_t size, ge::ComputeGraphPtr &graph) override;
  617. };
  618. } // namespace ge
  619. #endif // PARSER_TENSORFLOW_TENSORFLOW_PARSER_H_