| @@ -29,7 +29,21 @@ | |||||
| namespace ge { | namespace ge { | ||||
| namespace { | namespace { | ||||
| Status HandleNewOp(const NodePtr &node, const ComputeGraphPtr &compute_graph, const NodePtr &new_node) { | |||||
| bool HasOneNonDataNode(const ComputeGraphPtr &graph) { | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| int32_t non_data_nums = 0; | |||||
| for (const auto& n : graph->GetDirectNode()) { | |||||
| if (n->GetType() != parser::DATA) { | |||||
| non_data_nums++; | |||||
| } | |||||
| } | |||||
| GELOGD("graph has non data node num is %d", non_data_nums); | |||||
| return (non_data_nums == 1); | |||||
| } | |||||
| Status HandleNewOp(const NodePtr &node, | |||||
| const ComputeGraphPtr &compute_graph, | |||||
| const NodePtr &new_node, | |||||
| bool no_need_change_name) { | |||||
| GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
| GE_CHECK_NOTNULL(new_node); | GE_CHECK_NOTNULL(new_node); | ||||
| if (new_node->SetOwnerComputeGraph(compute_graph) != GRAPH_SUCCESS) { | if (new_node->SetOwnerComputeGraph(compute_graph) != GRAPH_SUCCESS) { | ||||
| @@ -37,8 +51,13 @@ Status HandleNewOp(const NodePtr &node, const ComputeGraphPtr &compute_graph, co | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| auto op_desc = new_node->GetOpDesc(); | auto op_desc = new_node->GetOpDesc(); | ||||
| static std::atomic_long new_node_index(0); | |||||
| auto new_name = "PartitionedCall_" + new_node->GetName() + "_" + to_string(new_node_index++); | |||||
| string new_name; | |||||
| if (no_need_change_name) { | |||||
| new_name = node->GetName(); | |||||
| } else { | |||||
| static std::atomic_long new_node_index(0); | |||||
| new_name = "PartitionedCall_" + new_node->GetName() + "_" + to_string(new_node_index++); | |||||
| } | |||||
| op_desc->SetName(new_name); | op_desc->SetName(new_name); | ||||
| bool ret = ge::AttrUtils::SetListStr(op_desc, | bool ret = ge::AttrUtils::SetListStr(op_desc, | ||||
| ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, | ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, | ||||
| @@ -91,11 +110,12 @@ Status ParserUtils::ExpandNodeToSubgraph(const Graph &subgraph, const NodePtr &n | |||||
| GE_CHECK_NOTNULL(compute_graph); | GE_CHECK_NOTNULL(compute_graph); | ||||
| // add subgraph node to graph. | // add subgraph node to graph. | ||||
| bool no_need_change_name = HasOneNonDataNode(sub_compute_graph); | |||||
| std::vector<NodePtr> input_nodes; | std::vector<NodePtr> input_nodes; | ||||
| for (const auto &n : sub_compute_graph->GetDirectNode()) { | for (const auto &n : sub_compute_graph->GetDirectNode()) { | ||||
| auto new_node = compute_graph->AddNode(n); | auto new_node = compute_graph->AddNode(n); | ||||
| GE_CHECK_NOTNULL(new_node); | GE_CHECK_NOTNULL(new_node); | ||||
| if (HandleNewOp(node, compute_graph, new_node) != SUCCESS) { | |||||
| if (HandleNewOp(node, compute_graph, new_node, no_need_change_name) != SUCCESS) { | |||||
| GELOGE(FAILED, "Handle new op[%s] for node[%s] failed.", new_node->GetName().c_str(), node->GetName().c_str()); | GELOGE(FAILED, "Handle new op[%s] for node[%s] failed.", new_node->GetName().c_str(), node->GetName().c_str()); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||