|
|
|
@@ -25,31 +25,65 @@ |
|
|
|
#include "graph/utils/tensor_utils.h" |
|
|
|
#include "graph/utils/type_utils.h" |
|
|
|
#include "register/op_registry.h" |
|
|
|
#include "graph/common/omg_util.h" |
|
|
|
|
|
|
|
namespace ge { |
|
|
|
namespace { |
|
|
|
constexpr uint8_t kDataInIndex = 0; |
|
|
|
constexpr uint8_t kDataOutIndex = 0; |
|
|
|
constexpr uint8_t kCaseArgIndex = 1; |
|
|
|
const int kDivisionConst = 2; |
|
|
|
const size_t kNumOfGetnextNode = 1; |
|
|
|
|
|
|
|
const std::string kMultiBatchCaseNode = "ascend_mbatch_shape_case"; |
|
|
|
const std::string kMultiBatchDataNode = "ascend_mbatch_shape_data"; |
|
|
|
const std::string kMultiBatchGetDynamicDimsNode = "ascend_mbatch_get_dynamic_dims_node"; |
|
|
|
const std::string kMultiBatchConstNode = "ascend_mbatch_shape_const"; |
|
|
|
const std::string kMultiBatchMapIndexNode = "ascend_mbatch_shape_mapindex"; |
|
|
|
const std::string kMultiBatchNodePostfix = "_ascend_mbatch_batch_"; |
|
|
|
const char *const kGetNextName = "IteratorV2"; |
|
|
|
} // namespace |
|
|
|
|
|
|
|
inline bool IsGetNextType(const NodePtr &node) { |
|
|
|
std::string original_type; |
|
|
|
GE_IF_BOOL_EXEC(GetOriginalType(node, original_type) != SUCCESS, |
|
|
|
GELOGW("Get original type failed."); return false); |
|
|
|
return (original_type == kGetNextName); |
|
|
|
} |
|
|
|
|
|
|
|
Status MultiBatchClonePass::Run(ComputeGraphPtr graph) { |
|
|
|
GE_IF_BOOL_EXEC(graph == nullptr, GELOGE(FAILED, "Original graph is nullptr"); return FAILED); |
|
|
|
if (graph->GetParentGraph() != nullptr) { |
|
|
|
GELOGD("Subgraph %s skip the MultiBatchClonePass", graph->GetName().c_str()); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
if (!GetLocalOmgContext().need_multi_batch) { |
|
|
|
GELOGI("No need to process_multi for no_train graph."); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
std::vector<NodePtr> data_nodes; |
|
|
|
std::vector<NodePtr> getnext_nosink_nodes; |
|
|
|
std::vector<NodePtr> getnext_sink_nodes; |
|
|
|
if (multibatch::CheckSequenceOfOptions(graph, data_nodes, getnext_nosink_nodes, getnext_sink_nodes) != SUCCESS) { |
|
|
|
GELOGE(PARAM_INVALID, "[Train_Dynamic] CheckSequenceOfOptions failed."); |
|
|
|
return PARAM_INVALID; |
|
|
|
} |
|
|
|
if (multibatch::UpdateNameOfInputShape(graph, data_nodes, getnext_nosink_nodes, getnext_sink_nodes) != SUCCESS) { |
|
|
|
GELOGE(PARAM_INVALID, "[Train_Dynamic] UpdateNameForInputShapeOfOption failed."); |
|
|
|
return PARAM_INVALID; |
|
|
|
} |
|
|
|
if (multibatch::DeleteIdentityInsertByAdapter(graph) != SUCCESS) { |
|
|
|
GELOGE(PARAM_INVALID, "[Train_Dynamic] DeleteIdentityInsertByAdapter failed."); |
|
|
|
return PARAM_INVALID; |
|
|
|
} |
|
|
|
if (!multibatch::InitDynamicParams(batch_shapes_)) { |
|
|
|
GELOGD("There is no multi-batch options, no need clone multi-batch graph"); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
if (multibatch::CheckNegativeCountOfOptions(batch_shapes_) != SUCCESS) { |
|
|
|
GELOGE(PARAM_INVALID, "[Train_Dynamic] Input_shape and dynamic_dims should set correct params."); |
|
|
|
return PARAM_INVALID; |
|
|
|
} |
|
|
|
GELOGD("Begin to run Multi-batch clone on graph: %s", graph->GetName().c_str()); |
|
|
|
GE_CHK_STATUS_RET(multibatch::CheckDynamicParams(batch_shapes_), "Invalid multi-batch param"); |
|
|
|
if (CollectIoNodes(graph) != SUCCESS) { |
|
|
|
@@ -66,21 +100,14 @@ Status MultiBatchClonePass::Run(ComputeGraphPtr graph) { |
|
|
|
|
|
|
|
(void)AttrUtils::GetStr(graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id_); |
|
|
|
ComputeGraphPtr branch = MakeShared<ComputeGraph>(graph->GetName()); |
|
|
|
if (branch == nullptr) { |
|
|
|
GELOGE(OUT_OF_MEMORY, "Create multi-batch graph failed"); |
|
|
|
return OUT_OF_MEMORY; |
|
|
|
} |
|
|
|
GE_IF_BOOL_EXEC(branch == nullptr, GELOGE(OUT_OF_MEMORY, "Create multi batch graph failed"); return OUT_OF_MEMORY); |
|
|
|
(void)AttrUtils::SetStr(branch, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id_); |
|
|
|
|
|
|
|
graph->InValid(); // Will modify, need topological again. |
|
|
|
graph->Swap(*branch); |
|
|
|
if (CreateRootGraph(graph) != SUCCESS) { |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
if (CreateSubgraphs(graph, branch) != SUCCESS) { |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
GE_CHK_STATUS_RET(CreateRootGraph(graph), "Construct root graph failed."); |
|
|
|
GE_CHK_STATUS_RET(CreateOriGraph(branch), "Construct original graph failed.") |
|
|
|
GE_CHK_STATUS_RET(CreateSubgraphs(graph, branch), "Construct subgraph failed."); |
|
|
|
|
|
|
|
GE_CHK_STATUS_RET(PruneDirectOutput(graph), "Prune direct output failed"); |
|
|
|
GELOGD("MultiBatchClonePass Leave"); |
|
|
|
@@ -95,9 +122,13 @@ Status MultiBatchClonePass::Run(ComputeGraphPtr graph) { |
|
|
|
/// |
|
|
|
Status MultiBatchClonePass::CollectIoNodes(const ComputeGraphPtr &graph) { |
|
|
|
for (const auto &node : graph->GetDirectNode()) { |
|
|
|
if (!GetLocalOmgContext().dynamic_node_type.empty() && IsGetNextType(node)) { |
|
|
|
all_data_nodes_.emplace_back(node); |
|
|
|
GE_CHK_STATUS_RET(InitParamsOfGetNext(node), "Init params of %s failed.", node->GetName().c_str()); |
|
|
|
} |
|
|
|
if (node->GetType() == DATA) { |
|
|
|
all_data_nodes_.emplace_back(node); |
|
|
|
} else if (node->GetType() == CONSTANT) { |
|
|
|
} else if (node->GetType() == CONSTANT || node->GetType() == CONSTANTOP) { |
|
|
|
all_const_nodes_.emplace_back(node); |
|
|
|
} else if (node->GetType() == NETOUTPUT) { |
|
|
|
all_output_nodes_.emplace_back(node); |
|
|
|
@@ -114,10 +145,16 @@ Status MultiBatchClonePass::CollectIoNodes(const ComputeGraphPtr &graph) { |
|
|
|
} |
|
|
|
|
|
|
|
int64_t data_index = 0; |
|
|
|
size_t getnext_node_count = 0; |
|
|
|
for (size_t i = 0; i < all_data_nodes_.size(); ++i) { |
|
|
|
if (IsGetNextType(all_data_nodes_[i])) { |
|
|
|
// just one getnext node in graph |
|
|
|
getnext_node_count++; |
|
|
|
continue; |
|
|
|
} |
|
|
|
const auto &op_desc = all_data_nodes_[i]->GetOpDesc(); |
|
|
|
if (!AttrUtils::GetInt(op_desc, ATTR_NAME_INDEX, data_index)) { |
|
|
|
(void)AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, i); |
|
|
|
(void)AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, i - getnext_node_count); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -133,7 +170,44 @@ Status MultiBatchClonePass::CollectIoNodes(const ComputeGraphPtr &graph) { |
|
|
|
"Remove edge failed"); |
|
|
|
} |
|
|
|
} |
|
|
|
GELOGD("Data count is %zu, const count is %zu, getnext count is %zu, output count is %zu, direct out count is %zu.", |
|
|
|
all_data_nodes_.size(), all_const_nodes_.size(), getnext_node_count, all_output_nodes_.size(), |
|
|
|
direct_output_.size()); |
|
|
|
|
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status MultiBatchClonePass::InitParamsOfGetNext(const NodePtr &node) { |
|
|
|
data_count_from_getnext_ = 0; |
|
|
|
getnext_sink_dynamic_dims_ = false; |
|
|
|
GE_CHECK_NOTNULL(node->GetOpDesc()); |
|
|
|
data_count_from_getnext_ = node->GetOpDesc()->GetOutputsSize(); |
|
|
|
if (GetLocalOmgContext().dynamic_node_type == GETNEXT) { |
|
|
|
data_count_from_getnext_ = data_count_from_getnext_ / kDivisionConst; |
|
|
|
for (size_t i = 0; i < data_count_from_getnext_; ++i) { |
|
|
|
GeTensorDesc output_desc = node->GetOpDesc()->GetOutputDesc(i); |
|
|
|
GELOGD("The %zu data shape from getnext sink is %s.", i, |
|
|
|
formats::JoinToString(output_desc.GetShape().GetDims()).c_str()); |
|
|
|
const auto &dims = output_desc.GetShape().GetDims(); |
|
|
|
if (std::all_of(dims.begin(), dims.end(), [](int64_t val) {return val >= 0; })) { |
|
|
|
GELOGD("The %zu data from %s is static.", i, node->GetName().c_str()); |
|
|
|
continue; |
|
|
|
} else { |
|
|
|
getnext_sink_dynamic_dims_ = true; |
|
|
|
GELOGD("Dynamic dims in the pattern of getnext sink."); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
if (node->GetOutControlAnchor() != nullptr) { |
|
|
|
for (const auto &peer_in_control_anchor : node->GetOutControlAnchor()->GetPeerInControlAnchors()) { |
|
|
|
NodePtr next_node = peer_in_control_anchor->GetOwnerNode(); |
|
|
|
GE_CHECK_NOTNULL(next_node); |
|
|
|
if (next_node->GetType() == CONSTANTOP) { |
|
|
|
out_control_nodes_.emplace_back(next_node); |
|
|
|
GELOGD("Control edge: %s connect with %s.", node->GetName().c_str(), next_node->GetName().c_str()); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -144,7 +218,11 @@ Status MultiBatchClonePass::CollectIoNodes(const ComputeGraphPtr &graph) { |
|
|
|
/// @return 0: SUCCESS / others: FAILED |
|
|
|
/// |
|
|
|
Status MultiBatchClonePass::CreateRootGraph(const ComputeGraphPtr &graph) { |
|
|
|
GELOGD("Start create root graph of %s.", graph->GetName().c_str()); |
|
|
|
uint32_t input_num = all_data_nodes_.size() + all_const_nodes_.size(); |
|
|
|
if (data_count_from_getnext_ != 0) { |
|
|
|
input_num = input_num + data_count_from_getnext_ - kNumOfGetnextNode; |
|
|
|
} |
|
|
|
uint32_t output_num = all_output_nodes_[0]->GetAllInDataAnchorsSize(); |
|
|
|
|
|
|
|
OpDescBuilder op_builder(kMultiBatchCaseNode, CASE); |
|
|
|
@@ -185,6 +263,10 @@ Status MultiBatchClonePass::CreateRootGraph(const ComputeGraphPtr &graph) { |
|
|
|
op_desc->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
if (!AttrUtils::SetBool(op_desc, ATTR_INSERT_BY_MBATCH, true)) { |
|
|
|
GELOGE(INTERNAL_ERROR, "Failed to add insert attr on case node %s", op_desc->GetName().c_str()); |
|
|
|
return INTERNAL_ERROR; |
|
|
|
} |
|
|
|
GE_CHK_STATUS_RET(multibatch::StampDynamicType(op_desc), "Set dynamic type failed"); |
|
|
|
|
|
|
|
GE_CHK_STATUS_RET(CreateIndexNode(graph), "Create index node failed"); |
|
|
|
@@ -202,7 +284,7 @@ Status MultiBatchClonePass::CreateRootGraph(const ComputeGraphPtr &graph) { |
|
|
|
/// @param [in] NodePtr node: index data node. |
|
|
|
/// @return 0: SUCCESS / others: FAILED |
|
|
|
/// |
|
|
|
Status MultiBatchClonePass::CreateIndexDataNode(const ComputeGraphPtr &graph, NodePtr &node) { |
|
|
|
Status MultiBatchClonePass::CreateIndexDataNode(const ComputeGraphPtr &graph, NodePtr &shape_node) { |
|
|
|
const OpDescPtr data_desc = MakeShared<OpDesc>(kMultiBatchDataNode, DATA); |
|
|
|
if (data_desc == nullptr) { |
|
|
|
GELOGE(OUT_OF_MEMORY, "Create multi-batch data node failed"); |
|
|
|
@@ -220,11 +302,15 @@ Status MultiBatchClonePass::CreateIndexDataNode(const ComputeGraphPtr &graph, No |
|
|
|
} |
|
|
|
|
|
|
|
size_t data_index = all_data_nodes_.size(); |
|
|
|
(void)AttrUtils::SetInt(data_desc, ATTR_NAME_INDEX, data_index); |
|
|
|
if (data_count_from_getnext_ != 0) { |
|
|
|
(void)AttrUtils::SetInt(data_desc, ATTR_NAME_INDEX, data_index - kNumOfGetnextNode); |
|
|
|
} else { |
|
|
|
(void)AttrUtils::SetInt(data_desc, ATTR_NAME_INDEX, data_index); |
|
|
|
} |
|
|
|
(void)AttrUtils::SetBool(data_desc, ATTR_INSERT_BY_MBATCH, true); |
|
|
|
|
|
|
|
node = graph->AddNode(data_desc); |
|
|
|
if (node == nullptr) { |
|
|
|
shape_node = graph->AddNode(data_desc); |
|
|
|
if (shape_node == nullptr) { |
|
|
|
GELOGE(OUT_OF_MEMORY, "Create multi-batch data node failed"); |
|
|
|
return OUT_OF_MEMORY; |
|
|
|
} |
|
|
|
@@ -286,15 +372,19 @@ Status MultiBatchClonePass::CreateIndexConstNode(const ComputeGraphPtr &graph, N |
|
|
|
/// @return 0: SUCCESS / others: FAILED |
|
|
|
/// |
|
|
|
Status MultiBatchClonePass::CreateIndexNode(const ComputeGraphPtr &graph) { |
|
|
|
// Data --> MapIndex --> Case |
|
|
|
NodePtr data_node; |
|
|
|
GE_CHK_STATUS_RET(CreateIndexDataNode(graph, data_node), "Create data node failed"); |
|
|
|
// Data/GetDynamicDims --> MapIndex --> Case |
|
|
|
if (!getnext_sink_dynamic_dims_) { |
|
|
|
GE_CHK_STATUS_RET(CreateIndexDataNode(graph, shape_node_), "Create data node failed"); |
|
|
|
} else { |
|
|
|
GE_CHK_STATUS_RET(CreateGetDynamicDimsNode(graph, shape_node_), "Create get dynamic dims node failed"); |
|
|
|
} |
|
|
|
|
|
|
|
NodePtr const_node; |
|
|
|
GE_CHK_STATUS_RET(CreateIndexConstNode(graph, const_node), "Create const node failed"); |
|
|
|
|
|
|
|
GELOGD("Shape node name is %s, type is %s, const node name is %s.", shape_node_->GetName().c_str(), |
|
|
|
shape_node_->GetType().c_str(), const_node->GetName().c_str()); |
|
|
|
OpDescBuilder op_builder(kMultiBatchMapIndexNode, "MapIndex"); |
|
|
|
op_builder.AddInput("x", data_node->GetOpDesc()->GetOutputDesc(0)) |
|
|
|
op_builder.AddInput("x", shape_node_->GetOpDesc()->GetOutputDesc(0)) |
|
|
|
.AddInput("data_seq", const_node->GetOpDesc()->GetOutputDesc(0)) |
|
|
|
.AddOutput("y", GeTensorDesc(GeShape(), FORMAT_ND, DT_INT32)); |
|
|
|
|
|
|
|
@@ -309,8 +399,10 @@ Status MultiBatchClonePass::CreateIndexNode(const ComputeGraphPtr &graph) { |
|
|
|
return OUT_OF_MEMORY; |
|
|
|
} |
|
|
|
|
|
|
|
if (GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), index_node->GetInDataAnchor(0)) != GRAPH_SUCCESS) { |
|
|
|
GELOGE(FAILED, "Failed to add edge between node:%s to MapIndex:%s", data_node->GetName().c_str(), |
|
|
|
GE_CHK_STATUS_RET(AddAttrForGetDynamicDims(shape_node_), "Failed to add attr for %s.", |
|
|
|
shape_node_->GetName().c_str()); |
|
|
|
if (GraphUtils::AddEdge(shape_node_->GetOutDataAnchor(0), index_node->GetInDataAnchor(0)) != GRAPH_SUCCESS) { |
|
|
|
GELOGE(FAILED, "Failed to add edge between node:%s to MapIndex:%s", shape_node_->GetName().c_str(), |
|
|
|
index_node->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
@@ -328,6 +420,120 @@ Status MultiBatchClonePass::CreateIndexNode(const ComputeGraphPtr &graph) { |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status MultiBatchClonePass::CreateGetDynamicDimsNode(const ComputeGraphPtr &graph, NodePtr &shape_node) { |
|
|
|
const OpDescPtr data_desc = MakeShared<OpDesc>(kMultiBatchGetDynamicDimsNode, GETDYNAMICDIMS); |
|
|
|
if (data_desc == nullptr) { |
|
|
|
GELOGE(OUT_OF_MEMORY, "Create multi-batch get dynamic dims node failed"); |
|
|
|
return OUT_OF_MEMORY; |
|
|
|
} |
|
|
|
|
|
|
|
// input of GetDynamicDims is shape_of_each_data, output is gear_info |
|
|
|
for (size_t i = 0; i < GetLocalOmgContext().user_input_dims.size(); ++i) { |
|
|
|
size_t input_shape_dims = GetLocalOmgContext().user_input_dims.at(i).second.size(); |
|
|
|
// add input desc without GeShape for const input, value of input_shape is 1 transferred by adapter |
|
|
|
if (input_shape_dims == 1 && GetLocalOmgContext().user_input_dims.at(i).second.at(0) == 0) { |
|
|
|
GeTensorDesc tensor_desc; |
|
|
|
tensor_desc.SetFormat(FORMAT_ND); |
|
|
|
tensor_desc.SetDataType(DT_INT32); |
|
|
|
auto ret = data_desc->AddInputDesc(tensor_desc); |
|
|
|
GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to add input desc for created data"); |
|
|
|
return FAILED); |
|
|
|
continue; |
|
|
|
} |
|
|
|
GeTensorDesc tensor_desc(GeShape({static_cast<int32_t>(input_shape_dims)}), FORMAT_ND, DT_INT32); |
|
|
|
auto ret = data_desc->AddInputDesc(tensor_desc); |
|
|
|
GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to add input desc for created data"); |
|
|
|
return FAILED); |
|
|
|
} |
|
|
|
GeTensorDesc tensor_desc(GeShape({static_cast<int32_t>(batch_shapes_.at(0).size())}), FORMAT_ND, DT_INT32); |
|
|
|
auto ret = data_desc->AddOutputDesc(tensor_desc); |
|
|
|
GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to add output desc for created data"); |
|
|
|
return FAILED); |
|
|
|
|
|
|
|
(void)AttrUtils::SetBool(data_desc, ATTR_INSERT_BY_MBATCH, true); |
|
|
|
|
|
|
|
shape_node = graph->AddNode(data_desc); |
|
|
|
if (shape_node == nullptr) { |
|
|
|
GELOGE(OUT_OF_MEMORY, "Create multi-batch dynamic dims node failed"); |
|
|
|
return OUT_OF_MEMORY; |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status MultiBatchClonePass::AddAttrForGetDynamicDims(const NodePtr &shape_node) { |
|
|
|
if (!getnext_sink_dynamic_dims_) { |
|
|
|
GELOGD("No need to add attr when not insert get dynamic dims node."); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
GELOGD("Add attr for :%s, type is %s:", shape_node->GetName().c_str(), shape_node->GetType().c_str()); |
|
|
|
if (!AttrUtils::SetInt(shape_node->GetOpDesc(), ATTR_GETNEXT_SINK_DATA_COUNT, data_count_from_getnext_)) { |
|
|
|
GELOGE(INTERNAL_ERROR, "set ATTR_GETNEXT_SINK_DATA_COUNT failed"); |
|
|
|
return INTERNAL_ERROR; |
|
|
|
} |
|
|
|
vector<int64_t> shape_info; |
|
|
|
for (size_t i = 0; i < GetLocalOmgContext().user_input_dims.size(); ++i) { |
|
|
|
if (GetLocalOmgContext().user_input_dims.at(i).second.size() == 1 && |
|
|
|
GetLocalOmgContext().user_input_dims.at(i).second.at(0) == 0) { |
|
|
|
shape_info.emplace_back(0); |
|
|
|
continue; |
|
|
|
} |
|
|
|
shape_info.emplace_back(GetLocalOmgContext().user_input_dims.at(i).second.size()); |
|
|
|
for (size_t j = 0; j < GetLocalOmgContext().user_input_dims.at(i).second.size(); ++j) { |
|
|
|
shape_info.emplace_back(GetLocalOmgContext().user_input_dims.at(i).second.at(j)); |
|
|
|
} |
|
|
|
} |
|
|
|
if (!AttrUtils::SetListInt(shape_node->GetOpDesc(), ATTR_GETNEXT_SINK_SHAPE_INFO, shape_info)) { |
|
|
|
GELOGE(INTERNAL_ERROR, "set ATTR_GETNEXT_SINK_SHAPE_INFO failed"); |
|
|
|
return INTERNAL_ERROR; |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status MultiBatchClonePass::LinkGetNextToGetDynamicDims(const NodePtr &getnext_node, const NodePtr &shape_node) { |
|
|
|
GELOGD("Start relink shape anchor of %s to %s.", getnext_node->GetName().c_str(), shape_node->GetName().c_str()); |
|
|
|
size_t input_index = 0; |
|
|
|
size_t data_count = getnext_node->GetAllOutDataAnchors().size() / kDivisionConst; |
|
|
|
for (size_t out_index = data_count; out_index < getnext_node->GetAllOutDataAnchors().size(); ++out_index, |
|
|
|
++input_index) { |
|
|
|
GELOGD("Start add %s of %zu out_anchor to %s of %zu in_anchor.", getnext_node->GetName().c_str(), out_index, |
|
|
|
shape_node->GetName().c_str(), input_index); |
|
|
|
auto out_data_anchor = getnext_node->GetOutDataAnchor(out_index); |
|
|
|
auto ret = GraphUtils::AddEdge(out_data_anchor, shape_node->GetInDataAnchor(input_index)); |
|
|
|
GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to link getnext %s to getdynamicdims %s", |
|
|
|
getnext_node->GetName().c_str(), shape_node->GetName().c_str()); |
|
|
|
return INTERNAL_ERROR); |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status MultiBatchClonePass::LinkGetDynamicDimsToNetOutput(const NodePtr &output_node) { |
|
|
|
if (!GetLocalOmgContext().dynamic_node_type.empty()) { |
|
|
|
if (!AttrUtils::SetStr(output_node->GetOpDesc(), ATTR_ALL_GEARS_INFO, GetLocalOmgContext().dynamic_dims)) { |
|
|
|
GELOGE(INTERNAL_ERROR, "Failed to set all gears info attr on netoutput %s.", output_node->GetName().c_str()); |
|
|
|
return INTERNAL_ERROR; |
|
|
|
} |
|
|
|
} |
|
|
|
if (getnext_sink_dynamic_dims_) { |
|
|
|
GELOGD("Start link %s to %s.", shape_node_->GetName().c_str(), output_node->GetName().c_str()); |
|
|
|
size_t input_index = output_node->GetAllInDataAnchors().size(); |
|
|
|
if (NodeUtils::AppendInputAnchor(output_node, input_index + 1) != GRAPH_SUCCESS) { |
|
|
|
GELOGE(INTERNAL_ERROR, "Append input anchor of %s of %zu failed.", output_node->GetName().c_str(), input_index); |
|
|
|
return INTERNAL_ERROR; |
|
|
|
} |
|
|
|
auto ret = GraphUtils::AddEdge(shape_node_->GetOutDataAnchor(kDataOutIndex), |
|
|
|
output_node->GetInDataAnchor(input_index)); |
|
|
|
GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to link netoutput %s to getdynamicdims %s", |
|
|
|
output_node->GetName().c_str(), shape_node_->GetName().c_str()); |
|
|
|
return INTERNAL_ERROR); |
|
|
|
if (!AttrUtils::SetBool(output_node->GetOpDesc(), ATTR_GETNEXT_SINK_DYNMAIC, true)) { |
|
|
|
GELOGE(INTERNAL_ERROR, "Failed to set getnext sink dynamic attr on netoutput %s.", |
|
|
|
output_node->GetName().c_str()); |
|
|
|
return INTERNAL_ERROR; |
|
|
|
} |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
/// |
|
|
|
/// @ingroup ge |
|
|
|
/// @brief Create input node for root graph. |
|
|
|
@@ -337,8 +543,10 @@ Status MultiBatchClonePass::CreateIndexNode(const ComputeGraphPtr &graph) { |
|
|
|
Status MultiBatchClonePass::CreateInputNode(const ComputeGraphPtr &graph) { |
|
|
|
// Data --> Case |
|
|
|
std::vector<NodePtr> all_data_nodes; |
|
|
|
const size_t arg_index = kCaseArgIndex; |
|
|
|
for (size_t i = 0; i < all_data_nodes_.size(); ++i) { |
|
|
|
size_t case_input_index = kCaseArgIndex; |
|
|
|
NodePtr getnext_node = nullptr; |
|
|
|
size_t input_index_of_getnext = 0; |
|
|
|
for (size_t i = 0; i < all_data_nodes_.size(); ++i, ++case_input_index) { |
|
|
|
const auto &node = all_data_nodes_[i]; |
|
|
|
const OpDescPtr op_desc = AttrUtils::CopyOpDesc(node->GetOpDesc()); |
|
|
|
if (op_desc == nullptr) { |
|
|
|
@@ -353,22 +561,60 @@ Status MultiBatchClonePass::CreateInputNode(const ComputeGraphPtr &graph) { |
|
|
|
op_desc->SetName(node->GetName()); |
|
|
|
const NodePtr &data = graph->AddNode(op_desc); |
|
|
|
GE_CHK_BOOL_EXEC(data != nullptr, return FAILED, "Add node[%s] to graph failed", op_desc->GetName().c_str()); |
|
|
|
if (GraphUtils::AddEdge(data->GetOutDataAnchor(0), case_node_->GetInDataAnchor(arg_index + i)) != GRAPH_SUCCESS) { |
|
|
|
GELOGE(FAILED, "Failed to add edge between Data:%s to Case:%s", |
|
|
|
data->GetName().c_str(), case_node_->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
if (IsGetNextType(node)) { |
|
|
|
getnext_node = data; |
|
|
|
input_index_of_getnext = case_input_index; |
|
|
|
case_input_index = case_input_index + data_count_from_getnext_; |
|
|
|
continue; |
|
|
|
} else { |
|
|
|
if (GraphUtils::AddEdge(data->GetOutDataAnchor(0), case_node_->GetInDataAnchor(case_input_index)) != |
|
|
|
GRAPH_SUCCESS) { |
|
|
|
GELOGE(FAILED, "Failed to add edge between Data:%s to Case:%s", data->GetName().c_str(), |
|
|
|
case_node_->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (SetMaxShapeToData(data) != SUCCESS) { |
|
|
|
if (SetMaxShape(data) != SUCCESS) { |
|
|
|
GELOGE(FAILED, "Set max shape of %s failed.", data->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
all_data_nodes.emplace_back(data); |
|
|
|
} |
|
|
|
if (getnext_node != nullptr) { |
|
|
|
if (LinkEdgeForGetNext(getnext_node, input_index_of_getnext) != SUCCESS) { |
|
|
|
GELOGE(FAILED, "Failed to link edge for %s.", getnext_node->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
if (SetMaxShape(getnext_node) != SUCCESS) { |
|
|
|
GELOGE(FAILED, "Set max shape of %s failed.", getnext_node->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
all_data_nodes.emplace_back(getnext_node); |
|
|
|
} |
|
|
|
|
|
|
|
all_data_nodes_.swap(all_data_nodes); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status MultiBatchClonePass::LinkEdgeForGetNext(const NodePtr &getnext_node, size_t &case_input_index) { |
|
|
|
GELOGD("Start link edge for %s, which is the %zu input of %s.", getnext_node->GetName().c_str(), |
|
|
|
case_input_index, case_node_->GetName().c_str()); |
|
|
|
for (size_t out_index = 0; out_index < data_count_from_getnext_; ++out_index, ++case_input_index) { |
|
|
|
if (GraphUtils::AddEdge(getnext_node->GetOutDataAnchor(out_index), |
|
|
|
case_node_->GetInDataAnchor(case_input_index)) != GRAPH_SUCCESS) { |
|
|
|
GELOGE(FAILED, "Failed to add data edge between %zu Data:%s to %zu Case:%s", out_index, |
|
|
|
getnext_node->GetName().c_str(), case_input_index, case_node_->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
} |
|
|
|
if (getnext_sink_dynamic_dims_) { |
|
|
|
GE_CHK_STATUS_RET(LinkGetNextToGetDynamicDims(getnext_node, shape_node_), "Failed to add link for %s.", |
|
|
|
shape_node_->GetName().c_str()); |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
/// |
|
|
|
/// @ingroup ge |
|
|
|
/// @brief Create Const node for root graph. |
|
|
|
@@ -378,7 +624,11 @@ Status MultiBatchClonePass::CreateInputNode(const ComputeGraphPtr &graph) { |
|
|
|
Status MultiBatchClonePass::CreateConstNode(const ComputeGraphPtr &graph) { |
|
|
|
// Const --> Case |
|
|
|
std::vector<NodePtr> all_const_nodes; |
|
|
|
const size_t arg_index = kCaseArgIndex + all_data_nodes_.size(); |
|
|
|
size_t arg_index = kCaseArgIndex + all_data_nodes_.size(); |
|
|
|
if (data_count_from_getnext_ != 0) { |
|
|
|
arg_index = arg_index + data_count_from_getnext_ - kNumOfGetnextNode; |
|
|
|
} |
|
|
|
|
|
|
|
for (size_t i = 0; i < all_const_nodes_.size(); ++i) { |
|
|
|
const auto &node = all_const_nodes_[i]; |
|
|
|
const OpDescPtr op_desc = AttrUtils::CopyOpDesc(node->GetOpDesc()); |
|
|
|
@@ -395,15 +645,35 @@ Status MultiBatchClonePass::CreateConstNode(const ComputeGraphPtr &graph) { |
|
|
|
const NodePtr &data = graph->AddNode(op_desc); |
|
|
|
GE_CHK_BOOL_EXEC(data != nullptr, return FAILED, "Add node[%s] to graph failed", op_desc->GetName().c_str()); |
|
|
|
if (GraphUtils::AddEdge(data->GetOutDataAnchor(0), case_node_->GetInDataAnchor(arg_index + i)) != GRAPH_SUCCESS) { |
|
|
|
GELOGE(FAILED, "Failed to add edge between Const:%s to Case:%s", |
|
|
|
data->GetName().c_str(), case_node_->GetName().c_str()); |
|
|
|
GELOGE(FAILED, "Failed to add edge between Const:%s to Case:%s", data->GetName().c_str(), |
|
|
|
case_node_->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
all_const_nodes.emplace_back(data); |
|
|
|
} |
|
|
|
ChangeConstToData(); |
|
|
|
all_const_nodes_.swap(all_const_nodes); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
void MultiBatchClonePass::ChangeConstToData() { |
|
|
|
size_t data_index = all_data_nodes_.size(); |
|
|
|
if (data_count_from_getnext_ != 0) { |
|
|
|
data_index = data_index + data_count_from_getnext_ - kNumOfGetnextNode; |
|
|
|
} |
|
|
|
for (size_t i = 0; i < all_const_nodes_.size(); ++i, ++data_index) { // Trans subgraph Const to Data. |
|
|
|
auto &const_node = all_const_nodes_[i]; |
|
|
|
bool need_change_type = true; |
|
|
|
for (const auto &dst_node : out_control_nodes_) { |
|
|
|
if (const_node->GetName() == dst_node->GetName()) { |
|
|
|
GELOGD("No need to change %s to data type.", const_node->GetName().c_str()); |
|
|
|
need_change_type = false; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
if (!need_change_type) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
const OpDescPtr &op_desc = all_const_nodes_[i]->GetOpDesc(); |
|
|
|
op_desc->SetType(DATA); |
|
|
|
(void)op_desc->DelAttr(ATTR_NAME_WEIGHTS); // Delete weight. |
|
|
|
@@ -413,9 +683,6 @@ Status MultiBatchClonePass::CreateConstNode(const ComputeGraphPtr &graph) { |
|
|
|
(void)AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, data_index); |
|
|
|
(void)NodeUtils::AppendInputAnchor(all_const_nodes_[i], 1); |
|
|
|
} |
|
|
|
|
|
|
|
all_const_nodes_.swap(all_const_nodes); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
/// |
|
|
|
@@ -461,7 +728,8 @@ Status MultiBatchClonePass::CreateOutputNode(const ComputeGraphPtr &graph) { |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
GE_CHK_STATUS_RET(LinkGetDynamicDimsToNetOutput(node), "Failed to add edge between %s to netoutput: %s.", |
|
|
|
shape_node_->GetName().c_str(), output->GetName().c_str()); |
|
|
|
all_output_nodes_.clear(); |
|
|
|
all_output_nodes_.emplace_back(node); |
|
|
|
return SUCCESS; |
|
|
|
@@ -473,34 +741,69 @@ Status MultiBatchClonePass::CreateOutputNode(const ComputeGraphPtr &graph) { |
|
|
|
/// @param [in] const NodePtr &data: data in Root/Case graph. |
|
|
|
/// @return 0: SUCCESS / others: FAILED |
|
|
|
/// |
|
|
|
Status MultiBatchClonePass::SetMaxShapeToData(const NodePtr &data) { |
|
|
|
auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape(); |
|
|
|
auto data_name = data->GetName(); |
|
|
|
Status MultiBatchClonePass::SetMaxShape(const NodePtr &data) { |
|
|
|
GELOGD("Start set max shape for %s.", data->GetName().c_str()); |
|
|
|
if (!IsGetNextType(data)) { |
|
|
|
if (SetMaxShapeToData(data, kDataOutIndex) != SUCCESS) { |
|
|
|
GELOGE(PARAM_INVALID, "Failed to update max shape of %s.", data->GetName().c_str()); |
|
|
|
return PARAM_INVALID; |
|
|
|
} |
|
|
|
} else { |
|
|
|
for (size_t out_anchor_index = 0; out_anchor_index < data_count_from_getnext_; ++out_anchor_index) { |
|
|
|
if (SetMaxShapeToData(data, out_anchor_index) != SUCCESS) { |
|
|
|
GELOGE(PARAM_INVALID, "Failed to update max shape of %s.", data->GetName().c_str()); |
|
|
|
return PARAM_INVALID; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status MultiBatchClonePass::SetMaxShapeToData(const NodePtr &node, size_t out_anchor_index) { |
|
|
|
GELOGD("Start update max shape of %s, %zu output.", node->GetName().c_str(), out_anchor_index); |
|
|
|
auto data_shape = NodeUtils::GetOutputDesc(*node, out_anchor_index).GetShape(); |
|
|
|
string data_name = node->GetName(); |
|
|
|
if (IsGetNextType(node)) { |
|
|
|
data_name.append("_").append(std::to_string(out_anchor_index)); |
|
|
|
} |
|
|
|
GELOGD("Update max shape of %s, shape dims is %s.", data_name.c_str(), |
|
|
|
formats::JoinToString(data_shape.GetDims()).c_str()); |
|
|
|
const auto &dims = data_shape.GetDims(); |
|
|
|
if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) { |
|
|
|
return SUCCESS; |
|
|
|
if (!IsGetNextType(node)) { |
|
|
|
if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) { |
|
|
|
GELOGD("No need to do anything for static data."); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
} else { |
|
|
|
if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) { |
|
|
|
if (getnext_sink_dynamic_dims_) { |
|
|
|
// need to update shape of Shape_node when getnext node has dynamic data |
|
|
|
GE_CHK_STATUS_RET(UpdateShapeOfShapeNode(node, out_anchor_index), "Failed to update shape of shape node"); |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
} |
|
|
|
(void)AttrUtils::SetListInt(data->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims()); |
|
|
|
(void)AttrUtils::SetListInt(node->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims()); |
|
|
|
|
|
|
|
GeTensorDesc tensor(NodeUtils::GetOutputDesc(*data, kDataOutIndex)); |
|
|
|
GeTensorDesc tensor(NodeUtils::GetOutputDesc(*node, kDataOutIndex)); |
|
|
|
std::vector<std::string> input_dims_str; |
|
|
|
for (size_t i = 0; i < batch_shapes_.size(); ++i) { |
|
|
|
auto shape = data_shape; |
|
|
|
auto ret = multibatch::CalcShape(data_to_dynamic_info_.at(data_name).at(i), shape); |
|
|
|
if (ret != SUCCESS) { |
|
|
|
GELOGE(ret, "Failed to calculate the shape for data node %s, the shape may not match", data->GetName().c_str()); |
|
|
|
GELOGE(ret, "Failed to calculate the shape for data node %s, the shape may not match", node->GetName().c_str()); |
|
|
|
return ret; |
|
|
|
} |
|
|
|
tensor.SetShape(shape); |
|
|
|
int64_t tensor_size = 0; |
|
|
|
(void)TensorUtils::GetTensorSizeInBytes(tensor, tensor_size); |
|
|
|
string input_str = TypeUtils::FormatToSerialString(tensor.GetFormat()) + ":" + |
|
|
|
TypeUtils::DataTypeToSerialString(tensor.GetDataType()) + ":" + data->GetName() + ":" + |
|
|
|
TypeUtils::DataTypeToSerialString(tensor.GetDataType()) + ":" + node->GetName() + ":" + |
|
|
|
std::to_string(tensor_size) + ":" + std::to_string(tensor.GetShape().GetDimNum()) + ":" + |
|
|
|
formats::JoinToString(tensor.GetShape().GetDims()); |
|
|
|
input_dims_str.emplace_back(input_str); |
|
|
|
} |
|
|
|
(void)AttrUtils::SetListStr(data->GetOpDesc(), "_all_origin_gears_inputs", input_dims_str); |
|
|
|
(void)AttrUtils::SetListStr(node->GetOpDesc(), "_all_origin_gears_inputs", input_dims_str); |
|
|
|
|
|
|
|
size_t max_shape_index = 0; |
|
|
|
int64_t max_size = 0; |
|
|
|
@@ -519,18 +822,72 @@ Status MultiBatchClonePass::SetMaxShapeToData(const NodePtr &data) { |
|
|
|
max_shape_index = i; |
|
|
|
} |
|
|
|
} |
|
|
|
return SetShapeToData(data_to_dynamic_info_.at(data_name).at(max_shape_index), node, data_shape, out_anchor_index); |
|
|
|
} |
|
|
|
|
|
|
|
return SetShapeToData(data_to_dynamic_info_.at(data_name).at(max_shape_index), data, data_shape); |
|
|
|
/// |
|
|
|
/// @ingroup ge |
|
|
|
/// @brief Set max shape to Data/GetNext node in root graph. |
|
|
|
/// @param [in] const std::vector<int64_t> &shapes: dims of shape. |
|
|
|
/// @param [in] const NodePtr &data: data in Root/Case graph. |
|
|
|
/// @param [in] GeShape &data_shape: dims of data node. |
|
|
|
/// @param [in] size_t out_anchor_index: out anchor index of data node. |
|
|
|
/// @return 0: SUCCESS / others: FAILED |
|
|
|
/// |
|
|
|
Status MultiBatchClonePass::SetShapeToData(const std::vector<int64_t> &shapes, const NodePtr &data, GeShape &data_shape, |
|
|
|
size_t out_anchor_index) { |
|
|
|
GELOGD("Start set shape to %zu out of %s.", out_anchor_index, data->GetName().c_str()); |
|
|
|
if (multibatch::CalcShape(shapes, data_shape) != SUCCESS) { |
|
|
|
GELOGE(INTERNAL_ERROR, "Failed to calculate the batched shape for data node %s, the shapes may not match", |
|
|
|
data->GetName().c_str()); |
|
|
|
return INTERNAL_ERROR; |
|
|
|
} |
|
|
|
|
|
|
|
if (NodeUtils::UpdateOutputShape(*data, out_anchor_index, data_shape) != GRAPH_SUCCESS) { |
|
|
|
GELOGE(INTERNAL_ERROR, "Failed to update output shape for data %s", data->GetName().c_str()); |
|
|
|
return INTERNAL_ERROR; |
|
|
|
} |
|
|
|
if (!IsGetNextType(data)) { |
|
|
|
if (NodeUtils::UpdateInputShape(*data, kDataInIndex, data_shape) != GRAPH_SUCCESS) { |
|
|
|
GELOGE(INTERNAL_ERROR, "Failed to update input shape for data %s", data->GetName().c_str()); |
|
|
|
return INTERNAL_ERROR; |
|
|
|
} |
|
|
|
} else { |
|
|
|
if (getnext_sink_dynamic_dims_) { |
|
|
|
// need to update shape of Shape_node when getnext_sink_dynamic |
|
|
|
GE_CHK_STATUS_RET(UpdateShapeOfShapeNode(data, out_anchor_index), "Failed to update shape of shape node"); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
GELOGI("Update the data %s input/output shape to the max %s", data->GetName().c_str(), |
|
|
|
formats::ShapeToString(data_shape).c_str()); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status MultiBatchClonePass::UpdateShapeOfShapeNode(const NodePtr &node, size_t out_anchor_index) { |
|
|
|
GELOGD("Start update output shape of shape node insert by adapter, which is the %zu out of %s.", out_anchor_index, |
|
|
|
node->GetName().c_str()); |
|
|
|
auto data_shape = NodeUtils::GetOutputDesc(*node, out_anchor_index).GetShape(); |
|
|
|
size_t shape_index = out_anchor_index + (node->GetAllOutDataAnchors().size() / kDivisionConst); |
|
|
|
GeTensorDesc output_desc = node->GetOpDesc()->GetOutputDesc(shape_index); |
|
|
|
std::vector<int64_t> output_dims = {static_cast<int64_t>(data_shape.GetDims().size())}; |
|
|
|
GeShape output_shape(output_dims); |
|
|
|
output_desc.SetShape(output_shape); |
|
|
|
if (node->GetOpDesc()->UpdateOutputDesc(shape_index, output_desc) != SUCCESS) { |
|
|
|
GELOGE(FAILED, "Update output desc fail."); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
/// |
|
|
|
/// @ingroup ge |
|
|
|
/// @brief Update Data node in Subgraph. |
|
|
|
/// @param [in] const NodePtr &data: data in Subgraph. |
|
|
|
/// @param [in] size_t index: The batch index. |
|
|
|
/// @param [in] size_t batch_index: The batch index. |
|
|
|
/// @return 0: SUCCESS / others: FAILED |
|
|
|
/// |
|
|
|
Status MultiBatchClonePass::UpdateSubgraphData(const NodePtr &data, size_t index) { |
|
|
|
Status MultiBatchClonePass::UpdateSubgraphData(const NodePtr &data, size_t batch_index) { |
|
|
|
int node_index = -1; |
|
|
|
if (!AttrUtils::GetInt(data->GetOpDesc(), ATTR_NAME_INDEX, node_index)) { |
|
|
|
GELOGE(FAILED, "Failed to get index from data[%s]", data->GetName().c_str()); |
|
|
|
@@ -545,6 +902,8 @@ Status MultiBatchClonePass::UpdateSubgraphData(const NodePtr &data, size_t index |
|
|
|
|
|
|
|
auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape(); |
|
|
|
const auto &dims = data_shape.GetDims(); |
|
|
|
GELOGD("Start update shape of %s , batch index is %zu, dims is %s.", data->GetName().c_str(), batch_index, |
|
|
|
formats::JoinToString(dims).c_str()); |
|
|
|
if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) { |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
@@ -559,35 +918,77 @@ Status MultiBatchClonePass::UpdateSubgraphData(const NodePtr &data, size_t index |
|
|
|
} |
|
|
|
|
|
|
|
auto parent_name = data_name.substr(0, pos); |
|
|
|
return SetShapeToData(data_to_dynamic_info_.at(parent_name).at(index), data, data_shape); |
|
|
|
return SetShapeToData(data_to_dynamic_info_.at(parent_name).at(batch_index), data, data_shape, kDataOutIndex); |
|
|
|
} |
|
|
|
|
|
|
|
/// |
|
|
|
/// @ingroup ge |
|
|
|
/// @brief Set max shape to Data node in root graph. |
|
|
|
/// @param [in] const std::vector<int64_t> &shapes: dims of shape. |
|
|
|
/// @param [in] const NodePtr &data: data in Root/Case graph. |
|
|
|
/// @param [in] GeShape &data_shape: dims of data node. |
|
|
|
/// @return 0: SUCCESS / others: FAILED |
|
|
|
/// |
|
|
|
Status MultiBatchClonePass::SetShapeToData(const vector<int64_t> &shapes, const NodePtr &data, GeShape &data_shape) { |
|
|
|
// must not be error, the calc result has been checked in function InsertSwitchNForData |
|
|
|
if (multibatch::CalcShape(shapes, data_shape) != SUCCESS) { |
|
|
|
return INTERNAL_ERROR; |
|
|
|
Status MultiBatchClonePass::CreateOriGraph(const ComputeGraphPtr &graph) { |
|
|
|
if (data_count_from_getnext_ == 0) { |
|
|
|
GELOGD("No need to change original graph without getnext node."); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
if (NodeUtils::UpdateInputShape(*data, kDataInIndex, data_shape) != GRAPH_SUCCESS) { |
|
|
|
GELOGE(INTERNAL_ERROR, "Failed to update input shape for data %s", data->GetName().c_str()); |
|
|
|
return INTERNAL_ERROR; |
|
|
|
GELOGD("Start change original graph: %s when exit getnext node.", graph->GetName().c_str()); |
|
|
|
size_t data_index = all_data_nodes_.size() - kNumOfGetnextNode; |
|
|
|
for (const auto &node : graph->GetDirectNode()) { |
|
|
|
if (IsGetNextType(node)) { |
|
|
|
for (size_t out_index = 0; out_index < data_count_from_getnext_; ++out_index, ++data_index) { |
|
|
|
auto out_data_anchor = node->GetOutDataAnchor(out_index); |
|
|
|
GE_IF_BOOL_EXEC(out_data_anchor == nullptr, continue); |
|
|
|
NodePtr data_node = CreateDataNode(graph, out_data_anchor, data_index); |
|
|
|
GE_IF_BOOL_EXEC(data_node == nullptr, GELOGE(INTERNAL_ERROR, "Create %zu data node failed.", |
|
|
|
out_data_anchor->GetIdx()); return INTERNAL_ERROR); |
|
|
|
for (auto &in_anchor : out_data_anchor->GetPeerInDataAnchors()) { |
|
|
|
GE_IF_BOOL_EXEC(in_anchor == nullptr, continue); |
|
|
|
NodePtr dst_node = in_anchor->GetOwnerNode(); |
|
|
|
if (GraphUtils::RemoveEdge(out_data_anchor, in_anchor) != GRAPH_SUCCESS) { |
|
|
|
GELOGE(INTERNAL_ERROR, "Failed to remove edge between %s to %s", node->GetName().c_str(), |
|
|
|
dst_node->GetName().c_str()); |
|
|
|
return INTERNAL_ERROR; |
|
|
|
} |
|
|
|
if (GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), dst_node->GetInDataAnchor(in_anchor->GetIdx())) != |
|
|
|
GRAPH_SUCCESS) { |
|
|
|
GELOGE(INTERNAL_ERROR, "Failed to add edge between %s to %s", data_node->GetName().c_str(), |
|
|
|
dst_node->GetName().c_str()); |
|
|
|
return INTERNAL_ERROR; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
if (graph->RemoveNode(node) != GRAPH_SUCCESS) { |
|
|
|
GELOGE(GRAPH_FAILED, "Remove node %s failed!", node->GetName().c_str()); |
|
|
|
return GRAPH_FAILED; |
|
|
|
} |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
if (NodeUtils::UpdateOutputShape(*data, kDataOutIndex, data_shape) != GRAPH_SUCCESS) { |
|
|
|
GELOGE(INTERNAL_ERROR, "Failed to update output shape for data %s", data->GetName().c_str()); |
|
|
|
return INTERNAL_ERROR; |
|
|
|
NodePtr MultiBatchClonePass::CreateDataNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor, |
|
|
|
size_t data_index) { |
|
|
|
size_t out_anchor_index = out_data_anchor->GetIdx(); |
|
|
|
std::string node_name = out_data_anchor->GetOwnerNode()->GetName() + "_" + std::to_string(out_anchor_index); |
|
|
|
OpDescPtr op_desc = MakeShared<OpDesc>(node_name, DATA); |
|
|
|
if (op_desc == nullptr) { |
|
|
|
GELOGE(OUT_OF_MEMORY, "Create data node failed."); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
(void)AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, data_index); |
|
|
|
|
|
|
|
GELOGI("Update %s input/output shape to %s", data->GetName().c_str(), formats::ShapeToString(data_shape).c_str()); |
|
|
|
return SUCCESS; |
|
|
|
OpDescPtr getnext_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc(); |
|
|
|
if (getnext_op_desc == nullptr) { |
|
|
|
GELOGE(OUT_OF_MEMORY, "Op desc of %s is nullptr.", out_data_anchor->GetOwnerNode()->GetName().c_str()); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
if (op_desc->AddInputDesc(getnext_op_desc->GetOutputDesc(out_anchor_index)) != GRAPH_SUCCESS) { |
|
|
|
GELOGE(INTERNAL_ERROR, "Add %s input desc failed.", op_desc->GetName().c_str()); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
if (op_desc->AddOutputDesc(getnext_op_desc->GetOutputDesc(out_anchor_index)) != GRAPH_SUCCESS) { |
|
|
|
GELOGE(INTERNAL_ERROR, "Add %s output desc failed.", op_desc->GetName().c_str()); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
NodePtr data_node = graph->AddNode(op_desc); |
|
|
|
GELOGD("Success create %s node.", data_node->GetName().c_str()); |
|
|
|
return data_node; |
|
|
|
} |
|
|
|
|
|
|
|
/// |
|
|
|
@@ -598,17 +999,14 @@ Status MultiBatchClonePass::SetShapeToData(const vector<int64_t> &shapes, const |
|
|
|
/// @return 0: SUCCESS / others: FAILED |
|
|
|
/// |
|
|
|
Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const ComputeGraphPtr &branch) { |
|
|
|
GELOGD("Start create subgraphs for %s.", graph->GetName().c_str()); |
|
|
|
const auto &op_desc = case_node_->GetOpDesc(); |
|
|
|
for (size_t i = 0; i < batch_shapes_.size(); ++i) { |
|
|
|
std::vector<NodePtr> input_nodes; |
|
|
|
std::vector<NodePtr> output_nodes; |
|
|
|
const std::string postfix = kMultiBatchNodePostfix + std::to_string(i); |
|
|
|
ComputeGraphPtr subgraph = (i == 0) ? branch : GraphUtils::CloneGraph(branch, postfix, input_nodes, output_nodes); |
|
|
|
if (subgraph == nullptr) { |
|
|
|
GELOGE(FAILED, "Create multi-batch case node failed"); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
|
|
|
|
GE_IF_BOOL_EXEC(subgraph == nullptr, GELOGE(FAILED, "Create multi-batch case node failed"); return FAILED); |
|
|
|
subgraph->SetName("Batch_" + std::to_string(i)); |
|
|
|
subgraph->SetParentNode(case_node_); |
|
|
|
subgraph->SetParentGraph(graph); |
|
|
|
@@ -621,6 +1019,7 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const |
|
|
|
op_desc->AddSubgraphName(key_name); |
|
|
|
op_desc->SetSubgraphInstanceName(i, subgraph->GetName()); |
|
|
|
|
|
|
|
GELOGD("The %s has %zu input, %zu output.", subgraph->GetName().c_str(), input_nodes.size(), output_nodes.size()); |
|
|
|
for (const auto &data : input_nodes) { |
|
|
|
GE_CHK_STATUS_RET(UpdateSubgraphData(data, i), "Update %s failed", subgraph->GetName().c_str()); |
|
|
|
} |
|
|
|
@@ -666,6 +1065,7 @@ Status MultiBatchClonePass::UpdateSubgraphOutput(const NodePtr &output_node) { |
|
|
|
/// @return 0: SUCCESS / others: FAILED |
|
|
|
/// |
|
|
|
Status MultiBatchClonePass::PruneDirectOutput(const ComputeGraphPtr &graph) { |
|
|
|
GELOGD("Start prune direct output."); |
|
|
|
const auto &func_desc = case_node_->GetOpDesc(); |
|
|
|
uint32_t unused_num = 0; |
|
|
|
uint32_t output_num = func_desc->GetOutputsSize(); |
|
|
|
@@ -710,6 +1110,7 @@ Status MultiBatchClonePass::PruneDirectOutput(const ComputeGraphPtr &graph) { |
|
|
|
/// |
|
|
|
Status MultiBatchClonePass::UpdateOutputTensor(uint32_t parent_index, uint32_t unused_num) { |
|
|
|
if (unused_num == 0) { |
|
|
|
GELOGD("No need to update output tensor."); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
|