You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_optimizer.cc 23 kB

5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
3 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
3 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
3 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486
  1. /**
  2. * Copyright 2020-2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph_optimizer.h"
  17. #include "graph/op_types.h"
  18. #include "common/types_map.h"
  19. #include "common/util.h"
  20. #include "framework/omg/parser/parser_inner_ctx.h"
  21. #include "graph/debug/ge_attr_define.h"
  22. #include "graph/utils/type_utils.h"
  23. #include "graph_functiondef.h"
  24. #include "parser/common/acl_graph_parser_util.h"
  25. #include "register/op_registry.h"
  26. namespace ge {
  27. REGISTER_OPTYPE_DEFINE(TF_MAXIMUM_GRAD, "MaximumGrad");
  28. REGISTER_OPTYPE_DEFINE(TF_MATMUL, "Matmul");
  29. REGISTER_OPTYPE_DEFINE(TFRELU6, "Relu6");
  30. REGISTER_OPTYPE_DEFINE(TF_BATCH_MATMUL, "BatchMatmul");
  31. } // namespace ge
  32. namespace ge {
  33. namespace {
  34. const char RRTVAL_NODE_NAME_SUFFIX[] = "_RetVal";
  35. const char *const kShapeNodeType = "Shape";
  36. const char *const kShapeNodeNamePrefix = "getnext_shape_";
  37. const char *const kIteratorType = "Iterator";
  38. const char *const kIteratorV2Type = "IteratorV2";
  39. const char *const kGetNextType = "IteratorGetNext";
  40. const char *const kDynGetNextType = "DynamicGetNext";
  41. } // namespace
  42. Status ParserGraphOptimizer::FusionFmkop() {
  43. GELOGI("graph_optimizer.cpp && FustionFmkop()");
  44. GE_CHECK_NOTNULL(graph_);
  45. std::unordered_map<string, std::vector<NodePtr>> node_cluser_Map;
  46. GE_CHK_STATUS_RET(MarkForFusion(node_cluser_Map), "find framework node to be fused fail.");
  47. GE_IF_BOOL_EXEC(node_cluser_Map.empty(), return SUCCESS);
  48. for (auto it = node_cluser_Map.begin(); it != node_cluser_Map.end(); ++it) {
  49. GE_CHK_STATUS_RET(UpdateGraph(it->second), "fusion framework nodes failed. node:%s", (it->first).c_str());
  50. }
  51. // fuse all fmkop and then delete node
  52. for (auto it = node_cluser_Map.begin(); it != node_cluser_Map.end(); ++it) {
  53. for (auto node : it->second) {
  54. GE_CHK_STATUS_RET(GraphUtils::IsolateNode(node, {}), "Isolate removed node: %s, type: %s failed",
  55. node->GetName().c_str(), node->GetType().c_str());
  56. GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(graph_, node),
  57. "Remove node: %s, type: %s without relink failed", node->GetName().c_str(),
  58. node->GetType().c_str());
  59. }
  60. }
  61. return SUCCESS;
  62. }
  63. Status ParserGraphOptimizer::MarkForFusion(unordered_map<string, vector<NodePtr>> &node_cluster_map) {
  64. GE_CHECK_NOTNULL(graph_);
  65. bool has_get_next = false;
  66. bool has_dyn_get_next = false;
  67. for (auto node : graph_->GetDirectNode()) {
  68. GE_CHECK_NOTNULL(node);
  69. if (node->GetType() == kDynGetNextType) {
  70. has_dyn_get_next = true;
  71. break;
  72. }
  73. GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue);
  74. string type;
  75. GE_CHK_STATUS_RET(ge::parser::GetOriginalType(node, type));
  76. if (type == kGetNextType) {
  77. has_get_next = true;
  78. break;
  79. }
  80. }
  81. return GetFusionCluster(has_get_next, has_dyn_get_next, node_cluster_map);
  82. }
  83. Status ParserGraphOptimizer::GetFusionCluster(const bool has_get_next, const bool has_dyn_get_next,
  84. unordered_map<string, vector<NodePtr>> &node_cluster_map) {
  85. GE_CHECK_NOTNULL(graph_);
  86. for (auto node : graph_->GetDirectNode()) {
  87. GE_CHECK_NOTNULL(node);
  88. GE_CHECK_NOTNULL(node->GetOpDesc());
  89. GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != ge::parser::FRAMEWORK_OP_TYPE, continue)
  90. string type;
  91. GE_CHK_STATUS_RET(ge::parser::GetOriginalType(node, type));
  92. if (type == kGetNextType) {
  93. vector<NodePtr> temp_node_cluser;
  94. for (auto in_anchor : node->GetAllInDataAnchors()) {
  95. OutDataAnchorPtr peer_out_anchor = in_anchor->GetPeerOutAnchor();
  96. GE_CHECK_NOTNULL(peer_out_anchor);
  97. NodePtr src_node = peer_out_anchor->GetOwnerNode();
  98. GE_CHECK_NOTNULL(src_node);
  99. temp_node_cluser.push_back(src_node);
  100. }
  101. temp_node_cluser.push_back(node);
  102. for (auto out_anchor : node->GetAllOutDataAnchors()) {
  103. GE_CHECK_NOTNULL(out_anchor);
  104. for (auto in_anchor : out_anchor->GetPeerInDataAnchors()) {
  105. GE_CHECK_NOTNULL(in_anchor);
  106. NodePtr dst_node = in_anchor->GetOwnerNode();
  107. GE_CHECK_NOTNULL(dst_node);
  108. GE_CHECK_NOTNULL(dst_node->GetOpDesc());
  109. if ((dst_node->GetName().find(kShapeNodeNamePrefix) != std::string::npos) &&
  110. (dst_node->GetOpDesc()->GetType() == kShapeNodeType)) {
  111. temp_node_cluser.emplace_back(dst_node);
  112. }
  113. }
  114. }
  115. if (temp_node_cluser.size() > 1) {
  116. vector<NodePtr> node_cluser;
  117. node_cluser.assign(temp_node_cluser.begin(), temp_node_cluser.end());
  118. node_cluster_map[temp_node_cluser[0]->GetName()] = node_cluser;
  119. }
  120. temp_node_cluser.clear();
  121. GELOGI("MarkForFusion, IteratorGetNext graph mark success.");
  122. }
  123. const bool dataset_init = (!has_get_next) && (!has_dyn_get_next) &&
  124. ((type == kIteratorType) || (type == kIteratorV2Type));
  125. if (dataset_init) {
  126. GE_CHK_STATUS_RET(FindFmkNodeCluser(node_cluster_map), "find framework node to be fused fail.");
  127. GELOGI("MarkForFusion, Iterator init graph mark success.");
  128. }
  129. }
  130. return SUCCESS;
  131. }
  132. // find frameworkOP
  133. Status ParserGraphOptimizer::FindFmkNodeCluser(unordered_map<string, vector<NodePtr>> &node_cluser_Map) const {
  134. vector<NodePtr> temp_node_cluser;
  135. for (auto node : graph_->GetDirectNode()) {
  136. GE_CHECK_NOTNULL(node);
  137. OpDescPtr temp_node_desc_ptr = node->GetOpDesc();
  138. GE_CHECK_NOTNULL(temp_node_desc_ptr);
  139. GE_IF_BOOL_EXEC(temp_node_desc_ptr->GetType() == ge::parser::DATA_TYPE, continue);
  140. if (temp_node_desc_ptr->GetType() == ge::parser::FRAMEWORK_OP_TYPE &&
  141. (temp_node_desc_ptr->GetName().find(RRTVAL_NODE_NAME_SUFFIX) == string::npos)) {
  142. temp_node_cluser.push_back(node);
  143. } else {
  144. if (temp_node_cluser.size() > 1) {
  145. vector<NodePtr> node_cluser;
  146. node_cluser.assign(temp_node_cluser.begin(), temp_node_cluser.end());
  147. node_cluser_Map[temp_node_cluser[0]->GetName()] = node_cluser;
  148. }
  149. temp_node_cluser.clear();
  150. }
  151. }
  152. if (temp_node_cluser.size() > 1) {
  153. vector<NodePtr> node_cluser;
  154. node_cluser.assign(temp_node_cluser.begin(), temp_node_cluser.end());
  155. node_cluser_Map[temp_node_cluser[0]->GetName()] = node_cluser;
  156. }
  157. return SUCCESS;
  158. }
  159. Status CollectNodeFuncs(vector<ge::NodePtr> &nodes, FunctionDefLibrary *library) {
  160. for (auto node : nodes) {
  161. GE_CHECK_NOTNULL(node);
  162. OpDescPtr opDef = node->GetOpDesc();
  163. string funcdefStr;
  164. ge::Buffer funcDefBytes;
  165. GE_IF_BOOL_EXEC(
  166. AttrUtils::GetBytes(opDef, ge::ATTR_NAME_FRAMEWORK_FUNC_DEF, funcDefBytes), FunctionDefLibrary funcLib;
  167. GE_CHECK_NOTNULL(funcDefBytes.GetData());
  168. string str(PtrToPtr<uint8_t, char_t>(funcDefBytes.GetData()), funcDefBytes.GetSize());
  169. GELOGI("FUNCDEF: Get function -> %s.", str.c_str()); GE_IF_BOOL_EXEC(
  170. funcLib.ParseFromArray(funcDefBytes.GetData(), funcDefBytes.GetSize()), library->MergeFrom(funcLib)));
  171. }
  172. return SUCCESS;
  173. }
  174. Status ParserGraphOptimizer::UpdateGraph(vector<NodePtr> &nodes) {
  175. ComputeGraphPtr sub_graph = nullptr;
  176. GE_MAKE_SHARED(sub_graph = std::make_shared<ComputeGraph>("subGraph"), sub_graph = nullptr; return PARAM_INVALID);
  177. unordered_map<string, NodePtr> node_map;
  178. vector<InDataAnchorPtr> input_anchors;
  179. vector<OutDataAnchorPtr> output_anchors;
  180. map<OutDataAnchorPtr, vector<InDataAnchorPtr>> output_in_map;
  181. vector<InControlAnchorPtr> input_control_anchors;
  182. vector<OutControlAnchorPtr> output_control_anchors;
  183. GE_CHK_STATUS_RET(InsertNode(sub_graph, nodes, input_anchors, output_anchors, output_in_map, input_control_anchors,
  184. output_control_anchors, node_map),
  185. "insert node to sub_graph failed.");
  186. GE_CHK_STATUS_RET(LinkInnerAnchor(node_map), "Link inner anchor failed.");
  187. std::unique_ptr<NodeDef> node_def(new (std::nothrow) NodeDef()); // tensorflow NodeDef
  188. GE_CHECK_NOTNULL(node_def);
  189. std::unique_ptr<FunctionDefLibrary> func_def_lib(new (std::nothrow) FunctionDefLibrary());
  190. GE_CHECK_NOTNULL(func_def_lib);
  191. // convert graph to FunctionDef
  192. if (nodes.size() == 0) {
  193. REPORT_INNER_ERROR("E19999", "Param nodes size must greater than 0");
  194. GELOGE(FAILED, "node size must greater than 0 .");
  195. return PARAM_INVALID;
  196. }
  197. GE_CHK_STATUS_RET(CollectNodeFuncs(nodes, func_def_lib.get()), "Collect functionDef in nodes failed.");
  198. GE_CHK_STATUS_RET(GraphToFunctionDef::BuildFunctionDef(sub_graph, nodes[0]->GetName(), func_def_lib.get(),
  199. node_def.get(), input_anchors, output_anchors),
  200. "Build functiondef failed.");
  201. string nodefStr;
  202. string funcdefStr;
  203. GE_IF_BOOL_EXEC(!node_def->SerializeToString(&nodefStr),
  204. REPORT_CALL_ERROR("E19999", "Serialize nodedef to string failed");
  205. GELOGE(PARAM_INVALID, "Serialize nodedef to string failed.");
  206. return PARAM_INVALID);
  207. GE_IF_BOOL_EXEC(!func_def_lib->SerializeToString(&funcdefStr),
  208. REPORT_CALL_ERROR("E19999", "Serialize func_def to string failed, ");
  209. GELOGE(PARAM_INVALID, "Serialize func_def to string failed.");
  210. return PARAM_INVALID);
  211. if (nodes.size() == 0) {
  212. GELOGE(FAILED, "nodes is empty.");
  213. return PARAM_INVALID;
  214. }
  215. std::string fusion_op_name;
  216. for (auto node : nodes) {
  217. fusion_op_name += node->GetName();
  218. }
  219. const uint32_t kFusionOpNameMaxLen = 1024;
  220. if (fusion_op_name.size() > kFusionOpNameMaxLen) {
  221. fusion_op_name = nodes[0]->GetName();
  222. }
  223. OpDescPtr fusion_node_opdef = nullptr;
  224. GE_MAKE_SHARED(
  225. fusion_node_opdef = std::make_shared<OpDesc>(fusion_op_name, nodes[0]->GetOpDesc()->GetType()),
  226. fusion_node_opdef = nullptr;
  227. return FAILED);
  228. std::string type = "";
  229. GE_CHK_STATUS_RET(ge::parser::GetOriginalType(nodes[0], type));
  230. (void)AttrUtils::SetStr(fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type);
  231. (void)AttrUtils::SetZeroCopyBytes(
  232. fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_FUNC_DEF,
  233. Buffer::CopyFrom(PtrToPtr<const char_t, const uint8_t>(funcdefStr.data()), funcdefStr.length()));
  234. (void)AttrUtils::SetZeroCopyBytes(
  235. fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_NODE_DEF,
  236. Buffer::CopyFrom(PtrToPtr<const char_t, const uint8_t>(nodefStr.data()), nodefStr.length()));
  237. (void)AttrUtils::SetInt(fusion_node_opdef, ge::ATTR_NAME_FRAMEWORK_FWK_TYPE, ge::GetParserContext().type);
  238. // reconstruct fusion_node and edges
  239. GE_CHK_STATUS_RET(RebuildOutputAnchors(output_anchors, fusion_node_opdef),
  240. "rebuild output edges to fusion node failed.");
  241. GE_CHK_STATUS_RET(RebuildInputAnchors(input_anchors, fusion_node_opdef),
  242. "rebuild input edges to fusion node failed.");
  243. NodePtr fusion_node = graph_->AddNode(fusion_node_opdef);
  244. // add Anchors
  245. GE_CHK_STATUS_RET(RebuildFusionNode(input_anchors, output_anchors, output_in_map, input_control_anchors,
  246. output_control_anchors, fusion_node),
  247. "rebuild node failed!");
  248. return SUCCESS;
  249. }
  250. Status ParserGraphOptimizer::InsertNode(ge::ComputeGraphPtr sub_graph, vector<ge::NodePtr> &nodes,
  251. vector<ge::InDataAnchorPtr> &input_anchors,
  252. vector<ge::OutDataAnchorPtr> &output_anchors,
  253. map<ge::OutDataAnchorPtr, vector<ge::InDataAnchorPtr>> &output_in_map,
  254. vector<ge::InControlAnchorPtr> &input_control_anchors,
  255. vector<ge::OutControlAnchorPtr> &output_control_anchors,
  256. unordered_map<string, ge::NodePtr> &node_map) {
  257. GE_CHECK_NOTNULL(sub_graph);
  258. for (NodePtr node : nodes) {
  259. GE_CHECK_NOTNULL(node);
  260. OpDescPtr op_def = node->GetOpDesc();
  261. NodePtr new_node = sub_graph->AddNode(op_def);
  262. GE_CHECK_NOTNULL(new_node);
  263. node_map[node->GetName()] = new_node;
  264. // Input
  265. for (auto in_anchor : node->GetAllInDataAnchors()) { // data
  266. OutDataAnchorPtr peer_out_anchor = in_anchor->GetPeerOutAnchor();
  267. GE_CHECK_NOTNULL(peer_out_anchor);
  268. vector<ge::NodePtr>::iterator iter = find(nodes.begin(), nodes.end(), peer_out_anchor->GetOwnerNode());
  269. GE_IF_BOOL_EXEC(iter == nodes.end(), input_anchors.emplace_back(in_anchor));
  270. }
  271. // Output
  272. for (auto out_anchor : node->GetAllOutDataAnchors()) {
  273. bool hasOutNode = false;
  274. // data anchor
  275. for (auto peer_in_anchor : out_anchor->GetPeerInDataAnchors()) {
  276. vector<ge::NodePtr>::iterator iter = find(nodes.begin(), nodes.end(), peer_in_anchor->GetOwnerNode());
  277. GE_IF_BOOL_EXEC(iter == nodes.end(), output_in_map[out_anchor].emplace_back(peer_in_anchor); hasOutNode = true);
  278. }
  279. GE_IF_BOOL_EXEC(hasOutNode, output_anchors.emplace_back(out_anchor));
  280. }
  281. InControlAnchorPtr node_in_control = node->GetInControlAnchor();
  282. GE_IF_BOOL_EXEC(
  283. node_in_control != nullptr, for (auto peer_out_anchor
  284. : node_in_control->GetPeerOutControlAnchors()) {
  285. vector<ge::NodePtr>::iterator iter = find(nodes.begin(), nodes.end(), peer_out_anchor->GetOwnerNode());
  286. GE_IF_BOOL_EXEC(iter == nodes.end(), input_control_anchors.emplace_back(node_in_control));
  287. });
  288. OutControlAnchorPtr node_out_control = node->GetOutControlAnchor();
  289. GE_IF_BOOL_EXEC(
  290. node_out_control != nullptr, for (auto peer_in_control_anchor
  291. : node_out_control->GetPeerInControlAnchors()) {
  292. vector<ge::NodePtr>::iterator iter = find(nodes.begin(), nodes.end(), peer_in_control_anchor->GetOwnerNode());
  293. GE_IF_BOOL_EXEC(iter == nodes.end(), output_control_anchors.emplace_back(node_out_control));
  294. });
  295. }
  296. return SUCCESS;
  297. }
  298. Status ParserGraphOptimizer::LinkInnerAnchor(unordered_map<string, ge::NodePtr> &node_map) const {
  299. for (auto node : graph_->GetDirectNode()) {
  300. GE_IF_BOOL_EXEC(node_map.count(node->GetName()) == 0, continue);
  301. NodePtr dst = node_map[node->GetName()];
  302. for (auto in_anchor : node->GetAllInDataAnchors()) {
  303. OutDataAnchorPtr peer_out_anchor = in_anchor->GetPeerOutAnchor();
  304. GE_CHECK_NOTNULL(peer_out_anchor);
  305. GE_IF_BOOL_EXEC(node_map.count(peer_out_anchor->GetOwnerNode()->GetName()) == 0, continue);
  306. NodePtr src = node_map[peer_out_anchor->GetOwnerNode()->GetName()];
  307. GE_IF_BOOL_EXEC(ge::GraphUtils::AddEdge(src->GetOutDataAnchor(peer_out_anchor->GetIdx()),
  308. dst->GetInDataAnchor(in_anchor->GetIdx())) != GRAPH_SUCCESS,
  309. REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:%d) and op:%s(%s)(index:%d) failed",
  310. src->GetName().c_str(), src->GetType().c_str(), peer_out_anchor->GetIdx(),
  311. dst->GetName().c_str(), dst->GetType().c_str(), in_anchor->GetIdx());
  312. GELOGE(FAILED,
  313. "LinkInnerAnchor Link data anchor failed, src node: %s, "
  314. "dst node: %s.",
  315. src->GetName().c_str(), dst->GetName().c_str());
  316. return FAILED);
  317. }
  318. InControlAnchorPtr node_in_control = node->GetInControlAnchor();
  319. GE_IF_BOOL_EXEC(
  320. node_in_control != nullptr, for (auto peer_out_ctl_anchor
  321. : node_in_control->GetPeerOutControlAnchors()) {
  322. GE_IF_BOOL_EXEC(node_map.count(peer_out_ctl_anchor->GetOwnerNode()->GetName()) == 0, continue);
  323. NodePtr src_ctrl = node_map[peer_out_ctl_anchor->GetOwnerNode()->GetName()];
  324. GE_IF_BOOL_EXEC(
  325. ge::GraphUtils::AddEdge(src_ctrl->GetOutControlAnchor(), dst->GetInControlAnchor()) != GRAPH_SUCCESS,
  326. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  327. src_ctrl->GetName().c_str(), src_ctrl->GetType().c_str(),
  328. dst->GetName().c_str(), dst->GetType().c_str());
  329. GELOGE(FAILED,
  330. "LinkInnerAnchor Link control anchor failed, src node: "
  331. "%s, dst node: %s.",
  332. src_ctrl->GetName().c_str(), dst->GetName().c_str());
  333. return FAILED);
  334. });
  335. }
  336. return SUCCESS;
  337. }
  338. // rebuild output anchor
  339. Status ParserGraphOptimizer::RebuildOutputAnchors(vector<ge::OutDataAnchorPtr> &output_anchors,
  340. ge::OpDescPtr fusion_op_desc) {
  341. std::vector<int64_t> output_list;
  342. GE_CHECK_NOTNULL(fusion_op_desc);
  343. // create input desc
  344. for (auto out_anchor : output_anchors) {
  345. NodePtr src_node = out_anchor->GetOwnerNode();
  346. GE_CHECK_NOTNULL(src_node);
  347. GeTensorDesc src_out_desc = src_node->GetOpDesc()->GetOutputDesc(out_anchor->GetIdx());
  348. GE_CHK_BOOL_EXEC(fusion_op_desc->AddOutputDesc(src_out_desc) == ge::GRAPH_SUCCESS, return FAILED);
  349. ge::DataType data_type = src_out_desc.GetDataType();
  350. const std::map<int32_t, int32_t>::const_iterator iter = GE_TENSORFLOW_DATA_TYPE_MAP.find((int32_t)data_type);
  351. GE_IF_BOOL_EXEC(
  352. iter == GE_TENSORFLOW_DATA_TYPE_MAP.end(),
  353. REPORT_INNER_ERROR("E19999", "datatype:%d of output:%d in node:%s:%s is not supported",
  354. data_type, out_anchor->GetIdx(), src_node->GetName().c_str(), src_node->GetName().c_str());
  355. GELOGE(PARAM_INVALID, "data_type %s not supported", ge::TypeUtils::DataTypeToSerialString(data_type).c_str());
  356. return PARAM_INVALID);
  357. int32_t dtype = iter->second;
  358. output_list.push_back((int64_t)dtype);
  359. GELOGI("FUNCDEF: output_list push_back %d.", dtype);
  360. }
  361. GE_IF_BOOL_EXEC(!output_list.empty(), (void)AttrUtils::SetListInt(fusion_op_desc, ge::T_OUT_DATATYPE, output_list));
  362. return SUCCESS;
  363. }
  364. // rebuild input desc
  365. Status ParserGraphOptimizer::RebuildInputAnchors(vector<ge::InDataAnchorPtr> &input_anchors,
  366. ge::OpDescPtr fusion_op_desc) {
  367. std::vector<int64_t> input_list;
  368. GE_CHECK_NOTNULL(fusion_op_desc);
  369. // add input desc
  370. for (auto in_anchor : input_anchors) {
  371. NodePtr dst_node = in_anchor->GetOwnerNode();
  372. GE_CHECK_NOTNULL(dst_node);
  373. auto tensorDescPtr = dst_node->GetOpDesc()->GetInputDescPtr(in_anchor->GetIdx());
  374. GE_CHECK_NOTNULL_EXEC(tensorDescPtr, return domi::FAILED);
  375. if (fusion_op_desc->AddInputDesc(*tensorDescPtr) != GRAPH_SUCCESS) {
  376. REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed",
  377. fusion_op_desc->GetName().c_str(), fusion_op_desc->GetType().c_str());
  378. GELOGE(FAILED, "Add fusion_op_desc AddInputDesc failed");
  379. return FAILED;
  380. }
  381. ge::DataType data_type = tensorDescPtr->GetDataType();
  382. const std::map<int32_t, int32_t>::const_iterator iter = GE_TENSORFLOW_DATA_TYPE_MAP.find((int32_t)data_type);
  383. GE_IF_BOOL_EXEC(
  384. iter == GE_TENSORFLOW_DATA_TYPE_MAP.end(),
  385. REPORT_INNER_ERROR("E19999", "datatype:%d of input:%d in node:%s:%s is not supported",
  386. data_type, in_anchor->GetIdx(), dst_node->GetName().c_str(), dst_node->GetName().c_str());
  387. GELOGE(PARAM_INVALID, "data_type %s not supported", ge::TypeUtils::DataTypeToSerialString(data_type).c_str());
  388. return PARAM_INVALID);
  389. int32_t dtype = iter->second;
  390. input_list.push_back((int64_t)dtype);
  391. GELOGI("FUNCDEF: input_list push_back %d.", dtype);
  392. }
  393. GE_IF_BOOL_EXEC(!input_list.empty(), (void)AttrUtils::SetListInt(fusion_op_desc, ge::T_IN_DATATYPE, input_list));
  394. return SUCCESS;
  395. }
  396. Status ParserGraphOptimizer::RebuildFusionNode(vector<ge::InDataAnchorPtr> &input_anchors,
  397. vector<ge::OutDataAnchorPtr> &output_anchors,
  398. map<ge::OutDataAnchorPtr, vector<ge::InDataAnchorPtr>> &output_in_map,
  399. vector<ge::InControlAnchorPtr> &input_control_anchors,
  400. vector<ge::OutControlAnchorPtr> &output_control_anchors,
  401. ge::NodePtr fusion_node) {
  402. GE_CHECK_NOTNULL(fusion_node);
  403. int32_t src_index = 0;
  404. for (auto out_anchor : output_anchors) {
  405. for (auto in_anchor : output_in_map[out_anchor]) {
  406. (void)in_anchor->Unlink(out_anchor);
  407. GE_RETURN_WITH_LOG_IF_ERROR(GraphUtils::AddEdge(fusion_node->GetOutDataAnchor(src_index), in_anchor),
  408. "Add anchor between fusion node and in anchor node!");
  409. }
  410. src_index++;
  411. }
  412. src_index = 0;
  413. for (auto in_anchor : input_anchors) {
  414. OutDataAnchorPtr out_anchor = in_anchor->GetPeerOutAnchor();
  415. out_anchor->Unlink(in_anchor);
  416. GE_RETURN_WITH_LOG_IF_ERROR(GraphUtils::AddEdge(out_anchor, fusion_node->GetInDataAnchor(src_index)),
  417. "Add anchor between out anchor node and fusion node!");
  418. src_index++;
  419. }
  420. for (auto out_control_anchor : output_control_anchors) {
  421. for (auto in_control_anchor : out_control_anchor->GetPeerInControlAnchors()) {
  422. in_control_anchor->Unlink(out_control_anchor);
  423. GE_RETURN_WITH_LOG_IF_ERROR(GraphUtils::AddEdge(fusion_node->GetOutControlAnchor(), in_control_anchor),
  424. "Add anchor between fusion node and in control anchor node!");
  425. }
  426. }
  427. for (auto in_control_anchor : input_control_anchors) {
  428. for (auto out_control_anchor : in_control_anchor->GetPeerOutControlAnchors()) {
  429. out_control_anchor->Unlink(in_control_anchor);
  430. GE_RETURN_WITH_LOG_IF_ERROR(GraphUtils::AddEdge(out_control_anchor, fusion_node->GetInControlAnchor()),
  431. "Add anchor between out control anchor node and fusion node!");
  432. }
  433. }
  434. return SUCCESS;
  435. }
  436. } // namespace ge