You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

parser_utils.cc 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. * Licensed under the Apache License, Version 2.0 (the "License");
  4. * you may not use this file except in compliance with the License.
  5. * You may obtain a copy of the License at
  6. * http://www.apache.org/licenses/LICENSE-2.0
  7. * Unless required by applicable law or agreed to in writing, software
  8. * distributed under the License is distributed on an "AS IS" BASIS,
  9. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. * See the License for the specific language governing permissions and
  11. * limitations under the License.
  12. */
  13. #include "parser_utils.h"
  14. #include "external/ge/ge_api_types.h"
  15. #include "framework/common/debug/ge_log.h"
  16. #include "common/util.h"
  17. #include "framework/omg/parser/parser_types.h"
  18. #include "graph/anchor.h"
  19. #include "graph/compute_graph.h"
  20. #include "graph/debug/ge_attr_define.h"
  21. #include "graph/utils/graph_utils.h"
  22. #include "graph/utils/node_adapter.h"
  23. #include "graph/utils/op_desc_utils.h"
  24. #include "register/op_registry.h"
  25. namespace ge {
  26. namespace {
  27. bool HasOneNonDataNode(const ComputeGraphPtr &graph) {
  28. GE_CHECK_NOTNULL(graph);
  29. int32_t non_data_nums = 0;
  30. for (const auto& n : graph->GetDirectNode()) {
  31. if (n->GetType() != parser::DATA) {
  32. non_data_nums++;
  33. }
  34. }
  35. GELOGD("graph has non data node num is %d", non_data_nums);
  36. return (non_data_nums == 1);
  37. }
  38. Status HandleNewOp(const NodePtr &node,
  39. const ComputeGraphPtr &compute_graph,
  40. const NodePtr &new_node,
  41. bool no_need_change_name) {
  42. GE_CHECK_NOTNULL(node);
  43. GE_CHECK_NOTNULL(new_node);
  44. if (new_node->SetOwnerComputeGraph(compute_graph) != GRAPH_SUCCESS) {
  45. REPORT_CALL_ERROR("E19999", "SetOwnerComputeGraph failed for node:%s", new_node->GetName().c_str());
  46. GELOGE(FAILED, "[Set][OwnerComputeGraph] for node:%s failed.", new_node->GetName().c_str());
  47. return FAILED;
  48. }
  49. auto op_desc = new_node->GetOpDesc();
  50. string new_name;
  51. if (no_need_change_name) {
  52. new_name = node->GetName();
  53. } else {
  54. static std::atomic_long new_node_index(0);
  55. new_name = "PartitionedCall_" + new_node->GetName() + "_" + to_string(new_node_index++);
  56. }
  57. op_desc->SetName(new_name);
  58. bool ret = ge::AttrUtils::SetListStr(op_desc,
  59. ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES,
  60. std::move(std::vector<std::string>{node->GetName()}));
  61. if (!ret) {
  62. GELOGW("Set %s to %s fail.", ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES.c_str(), op_desc->GetName().c_str());
  63. }
  64. GELOGD("Handle new op[%s] for node[%s] success.", new_node->GetName().c_str(), node->GetName().c_str());
  65. return SUCCESS;
  66. }
  67. }
  68. Status ParserUtils::ExpandOneToManyGraph(const Graph &graph, OutputMapping &output_mapping) {
  69. GELOGD("Begin run ParserUtils::ExpandOneToManyGraph.");
  70. for (const auto &gn : graph.GetDirectNode()) {
  71. NodePtr n = NodeAdapter::GNode2Node(gn);
  72. GE_CHECK_NOTNULL(n);
  73. std::string ori_type;
  74. (void)AttrUtils::GetStr(n->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, ori_type);
  75. domi::ParseOpToGraphFunc parse_op_to_graph_func =
  76. domi::OpRegistry::Instance()->GetParseOpToGraphFunc(n->GetType(), ori_type);
  77. if (parse_op_to_graph_func == nullptr) {
  78. GELOGD("node:%s type:%s ori type:%s has no parse_op_to_graph_func.",
  79. n->GetName().c_str(), n->GetType().c_str(), ori_type.c_str());
  80. continue;
  81. }
  82. GELOGI("node:%s type:%s ori type:%s has registered one to many parser func.",
  83. n->GetName().c_str(), n->GetType().c_str(), ori_type.c_str());
  84. Graph subgraph("one_to_many_graph");
  85. Operator op = OpDescUtils::CreateOperatorFromNode(n);
  86. Status ret = parse_op_to_graph_func(op, subgraph);
  87. if (ret != SUCCESS) {
  88. REPORT_CALL_ERROR("E19999", "Get one to many graph failed for op:%s.", op.GetName().c_str());
  89. GELOGE(FAILED, "[Invoke][ParseOpToGraphFunc]Get one to many graph failed for op:%s.", op.GetName().c_str());
  90. return FAILED;
  91. }
  92. ret = ExpandNodeToSubgraph(subgraph, n, graph, output_mapping);
  93. if (ret != SUCCESS) {
  94. GELOGE(FAILED, "[Invoke][ExpandNodeToSubgraph]Expand one to many graph failed for op:%s.", op.GetName().c_str());
  95. return FAILED;
  96. }
  97. }
  98. GELOGD("run ParserUtils::ExpandOneToManyGraph success.");
  99. return SUCCESS;
  100. }
  101. Status ParserUtils::ExpandNodeToSubgraph(const Graph &subgraph, const NodePtr &node, const Graph &graph,
  102. OutputMapping &output_mapping) {
  103. ComputeGraphPtr sub_compute_graph = GraphUtils::GetComputeGraph(subgraph);
  104. GE_CHECK_NOTNULL(sub_compute_graph);
  105. ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph);
  106. GE_CHECK_NOTNULL(compute_graph);
  107. // add subgraph node to graph.
  108. bool no_need_change_name = HasOneNonDataNode(sub_compute_graph);
  109. std::vector<NodePtr> input_nodes;
  110. for (const auto &n : sub_compute_graph->GetDirectNode()) {
  111. auto new_node = compute_graph->AddNode(n);
  112. GE_CHECK_NOTNULL(new_node);
  113. if (HandleNewOp(node, compute_graph, new_node, no_need_change_name) != SUCCESS) {
  114. GELOGE(FAILED, "[Handle][NewOp][%s] for node[%s] failed.", new_node->GetName().c_str(), node->GetName().c_str());
  115. return FAILED;
  116. }
  117. if (new_node->GetType() == ge::parser::DATA) {
  118. input_nodes.emplace_back(new_node);
  119. }
  120. }
  121. // handle input context.
  122. Status ret = HandleInputContext(node, input_nodes, compute_graph);
  123. if (ret != SUCCESS) {
  124. GELOGE(FAILED, "[Run][HandleInputContext] failed, node:%s.", node->GetName().c_str());
  125. return FAILED;
  126. }
  127. // handle output context.
  128. std::vector<std::pair<NodePtr, int32_t>> out_node_index = sub_compute_graph->GetGraphOutNodesInfo();
  129. ret = HandleOutputContext(node, out_node_index, output_mapping);
  130. if (ret != SUCCESS) {
  131. GELOGE(FAILED, "[Run][HandleOutputContext] failed, node:%s.", node->GetName().c_str());
  132. return FAILED;
  133. }
  134. graphStatus graph_status = GraphUtils::RemoveNodeWithoutRelink(compute_graph, node);
  135. if (graph_status != GRAPH_SUCCESS) {
  136. REPORT_CALL_ERROR("E19999", "Remove node:%s from graph:%s failed.", node->GetName().c_str(),
  137. compute_graph->GetName().c_str());
  138. GELOGE(FAILED, "[Remove][Node] %s from graph:%s failed.", node->GetName().c_str(),
  139. compute_graph->GetName().c_str());
  140. return FAILED;
  141. }
  142. graph_status = compute_graph->TopologicalSorting();
  143. if (graph_status != GRAPH_SUCCESS) {
  144. REPORT_CALL_ERROR("E19999", "TopologicalSorting failed, graph:%s.", compute_graph->GetName().c_str());
  145. GELOGE(FAILED, "[Invoke][TopologicalSorting] failed, graph:%s.", compute_graph->GetName().c_str());
  146. return FAILED;
  147. }
  148. return SUCCESS;
  149. }
  150. Status ParserUtils::HandleInputContext(const NodePtr &node,
  151. const std::vector<NodePtr> &input_nodes,
  152. const ComputeGraphPtr &compute_graph) {
  153. GE_CHECK_NOTNULL(node);
  154. for (const auto &in_n : input_nodes) {
  155. GE_CHECK_NOTNULL(in_n);
  156. int index;
  157. if (!AttrUtils::GetInt(in_n->GetOpDesc(), ATTR_NAME_INDEX, index)) {
  158. REPORT_INNER_ERROR("E19999", "GetInt failed, node:%s", in_n->GetName().c_str());
  159. GELOGE(FAILED, "[Get][AttrIndex] of node:%s failed.", in_n->GetName().c_str());
  160. return FAILED;
  161. }
  162. GELOGD("Begin to handle input node:%s with index:%d.", in_n->GetName().c_str(), index);
  163. // get node's in data anchor and peer out anchor
  164. auto node_in_anchor = node->GetInDataAnchor(index);
  165. GE_CHECK_NOTNULL(node_in_anchor);
  166. auto src_out_anchor = node_in_anchor->GetPeerOutAnchor();
  167. GE_CHECK_NOTNULL(src_out_anchor);
  168. auto data_out_anchor = in_n->GetOutDataAnchor(0);
  169. GE_CHECK_NOTNULL(data_out_anchor);
  170. for (const auto &peer_in_anchor : data_out_anchor->GetPeerInDataAnchors()) {
  171. // add data edge
  172. graphStatus ret = GraphUtils::RemoveEdge(data_out_anchor, peer_in_anchor);
  173. if (ret != GRAPH_SUCCESS) {
  174. REPORT_CALL_ERROR("E19999", "remove edge from %s to %s failed.",
  175. data_out_anchor->GetOwnerNode()->GetName().c_str(),
  176. peer_in_anchor->GetOwnerNode()->GetName().c_str());
  177. GELOGE(FAILED, "[Remove][Edge] from %s to %s failed.", data_out_anchor->GetOwnerNode()->GetName().c_str(),
  178. peer_in_anchor->GetOwnerNode()->GetName().c_str());
  179. return FAILED;
  180. }
  181. ret = GraphUtils::RemoveEdge(src_out_anchor, node_in_anchor);
  182. if (ret != GRAPH_SUCCESS) {
  183. REPORT_CALL_ERROR("E19999", "remove edge from %s to %s failed.",
  184. src_out_anchor->GetOwnerNode()->GetName().c_str(),
  185. node_in_anchor->GetOwnerNode()->GetName().c_str());
  186. GELOGE(FAILED, "[Remove][Edge] from %s to %s failed.", src_out_anchor->GetOwnerNode()->GetName().c_str(),
  187. node_in_anchor->GetOwnerNode()->GetName().c_str());
  188. return FAILED;
  189. }
  190. ret = GraphUtils::AddEdge(src_out_anchor, peer_in_anchor);
  191. if (ret != GRAPH_SUCCESS) {
  192. REPORT_CALL_ERROR("E19999", "add edge from %s to %s failed.",
  193. src_out_anchor->GetOwnerNode()->GetName().c_str(),
  194. peer_in_anchor->GetOwnerNode()->GetName().c_str());
  195. GELOGE(FAILED, "[Add][Edge] from %s to %s failed.", src_out_anchor->GetOwnerNode()->GetName().c_str(),
  196. peer_in_anchor->GetOwnerNode()->GetName().c_str());
  197. return FAILED;
  198. }
  199. // add control edge
  200. if (node->GetInControlAnchor() != nullptr) {
  201. for (const auto &out_anchor : node->GetInControlAnchor()->GetPeerAnchors()) {
  202. if (GraphUtils::AddEdge(out_anchor, peer_in_anchor->GetOwnerNode()->GetInControlAnchor()) != GRAPH_SUCCESS) {
  203. REPORT_CALL_ERROR("E19999", "add control edge from %s to %s failed.",
  204. out_anchor->GetOwnerNode()->GetName().c_str(),
  205. peer_in_anchor->GetOwnerNode()->GetName().c_str());
  206. GELOGE(FAILED, "[Invoke][AddEdge]add control edge from %s to %s failed.",
  207. out_anchor->GetOwnerNode()->GetName().c_str(), peer_in_anchor->GetOwnerNode()->GetName().c_str());
  208. return FAILED;
  209. }
  210. }
  211. }
  212. }
  213. graphStatus ret = GraphUtils::RemoveNodeWithoutRelink(compute_graph, in_n);
  214. if (ret != GRAPH_SUCCESS) {
  215. REPORT_CALL_ERROR("E19999", "RemoveNodeWithoutRelink failed, graph:%s, node:%s.",
  216. compute_graph->GetName().c_str(), in_n->GetName().c_str());
  217. GELOGE(FAILED, "[Remove][Node] %s failed, graph:%s.", in_n->GetName().c_str(), compute_graph->GetName().c_str());
  218. return FAILED;
  219. }
  220. }
  221. return SUCCESS;
  222. }
  223. Status ParserUtils::HandleOutputContext(const NodePtr &node,
  224. const std::vector<std::pair<NodePtr, int32_t>> &out_node_index,
  225. OutputMapping &output_mapping) {
  226. GE_CHECK_NOTNULL(node);
  227. GELOGD("The size of out node is %zu", out_node_index.size());
  228. for (size_t index = 0; index < out_node_index.size(); index++) {
  229. auto node_out_anchor = node->GetOutDataAnchor(index);
  230. if (node_out_anchor == nullptr) {
  231. continue;
  232. }
  233. NodePtr out_node = out_node_index[index].first;
  234. int32_t out_index = out_node_index[index].second;
  235. GELOGD("Begin to handle output node:%s[%d] with index:%zu", out_node->GetName().c_str(), out_index, index);
  236. std::string key = GenOutputKey({node->GetName(), index});
  237. output_mapping[key] = std::make_pair(out_node->GetName(), out_index);
  238. auto src_out_anchor = out_node->GetOutDataAnchor(out_index); // get out node's out anchor.
  239. GE_CHECK_NOTNULL(src_out_anchor);
  240. for (const auto &dest_in_anchor : node_out_anchor->GetPeerInDataAnchors()) {
  241. graphStatus ret = GraphUtils::RemoveEdge(node_out_anchor, dest_in_anchor);
  242. if (ret != GRAPH_SUCCESS) {
  243. REPORT_CALL_ERROR("E19999", "remove edge from node %s to node %s failed.",
  244. node_out_anchor->GetOwnerNode()->GetName().c_str(),
  245. dest_in_anchor->GetOwnerNode()->GetName().c_str());
  246. GELOGE(FAILED, "[Remove][Edge] from node %s to node %s failed.",
  247. node_out_anchor->GetOwnerNode()->GetName().c_str(),
  248. dest_in_anchor->GetOwnerNode()->GetName().c_str());
  249. return FAILED;
  250. }
  251. ret = GraphUtils::AddEdge(src_out_anchor, dest_in_anchor);
  252. if (ret != GRAPH_SUCCESS) {
  253. REPORT_CALL_ERROR("E19999", "Add edge from %s to %s failed.",
  254. src_out_anchor->GetOwnerNode()->GetName().c_str(),
  255. dest_in_anchor->GetOwnerNode()->GetName().c_str());
  256. GELOGE(FAILED, "[Add][Edge] from %s to %s failed.", src_out_anchor->GetOwnerNode()->GetName().c_str(),
  257. dest_in_anchor->GetOwnerNode()->GetName().c_str());
  258. return FAILED;
  259. }
  260. }
  261. }
  262. return SUCCESS;
  263. }
  264. string ParserUtils::GenOutputKey(const OutputNodeInfo &node_info) {
  265. return node_info.first + ":" + std::to_string(node_info.second);
  266. }
  267. void ParserUtils::UpdateOutputNodeInfo(const OutputMapping &final_output_nodes, OutputNodeInfo &output_node_info) {
  268. std::string key = ParserUtils::GenOutputKey(output_node_info);
  269. auto iter = final_output_nodes.find(key);
  270. if (iter != final_output_nodes.end()) {
  271. output_node_info = iter->second;
  272. GELOGD("Update output node info, origin[%s], now[%s].",
  273. key.c_str(), ParserUtils::GenOutputKey(output_node_info).c_str());
  274. }
  275. }
  276. void ParserUtils::UpdateOutputCtx(const OutputMapping &final_output_nodes, OutputMapping &tensor_to_nodes) {
  277. for (auto &tensor_to_node : tensor_to_nodes) {
  278. std::string tensor_name = tensor_to_node.first;
  279. auto &output_node_info = tensor_to_node.second;
  280. UpdateOutputNodeInfo(final_output_nodes, output_node_info);
  281. }
  282. }
  283. } // namespace ge