|
|
|
@@ -21,6 +21,7 @@ |
|
|
|
#include "common/ge/ge_util.h" |
|
|
|
#include "graph/common/omg_util.h" |
|
|
|
#include "graph/utils/type_utils.h" |
|
|
|
#include "common/formats/utils/formats_trans_utils.h" |
|
|
|
|
|
|
|
namespace ge { |
|
|
|
Status MultiBatchPass::Run(ComputeGraphPtr graph) { |
|
|
|
@@ -72,6 +73,8 @@ Status MultiBatchPass::Run(ComputeGraphPtr graph) { |
|
|
|
|
|
|
|
for (const NodePtr &node : bypass_nodes_) { |
|
|
|
if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != GRAPH_SUCCESS) { |
|
|
|
REPORT_CALL_ERROR("E19999", "Remove node:%s(%s) without relink in graph:%s failed", |
|
|
|
node->GetName().c_str(), node->GetType().c_str(), graph->GetName().c_str()); |
|
|
|
GELOGE(FAILED, "Remove SwitchN nodes %s failed.", node->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
@@ -139,11 +142,15 @@ Status MultiBatchPass::FindPredValue(const ComputeGraphPtr &graph, OutDataAnchor |
|
|
|
|
|
|
|
const auto &in_data_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); |
|
|
|
if (in_data_anchor == nullptr) { |
|
|
|
REPORT_INNER_ERROR("E19999", "Index:%u data anchor of node:%s(%s) is nullptr, check invalid", |
|
|
|
SWITCH_PRED_INPUT, node->GetName().c_str(), node->GetType().c_str()); |
|
|
|
GELOGE(FAILED, "FindPredInput failed, in_data_anchor is null, node:%s.", node->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
const auto &pred_input = in_data_anchor->GetPeerOutAnchor(); |
|
|
|
if (pred_input == nullptr) { |
|
|
|
REPORT_INNER_ERROR("E19999", "Index:%u data anchor of node:%s(%s), its peer anchor is nullptr, check invalid", |
|
|
|
SWITCH_PRED_INPUT, node->GetName().c_str(), node->GetType().c_str()); |
|
|
|
GELOGE(FAILED, "FindPredInput failed, pred_input is null, node:%s.", node->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
@@ -151,6 +158,8 @@ Status MultiBatchPass::FindPredValue(const ComputeGraphPtr &graph, OutDataAnchor |
|
|
|
if (pred_value == nullptr) { |
|
|
|
pred_value = pred_input; |
|
|
|
} else if (pred_value != pred_input) { |
|
|
|
REPORT_INNER_ERROR("E19999", "Multi pred_value of case node exist in graph:%s, check invalid", |
|
|
|
graph->GetName().c_str()); |
|
|
|
GELOGE(FAILED, "Multi pred_value node exist."); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
@@ -163,6 +172,7 @@ Status MultiBatchPass::FindPredValue(const ComputeGraphPtr &graph, OutDataAnchor |
|
|
|
} |
|
|
|
|
|
|
|
if (pred_value == nullptr) { |
|
|
|
REPORT_INNER_ERROR("E19999", "Find Pred Input of case node in graph:%s failed", graph->GetName().c_str()); |
|
|
|
GELOGE(FAILED, "FindPredInput failed, pred_value is null."); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
@@ -179,14 +189,22 @@ Status MultiBatchPass::GetDynamicType() { |
|
|
|
for (const auto &switch_n : switch_n_nodes_) { |
|
|
|
int32_t dynamic_type = static_cast<int32_t>(FIXED); |
|
|
|
if (!AttrUtils::GetInt(switch_n->GetOpDesc(), ATTR_DYNAMIC_TYPE, dynamic_type)) { |
|
|
|
REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", ATTR_DYNAMIC_TYPE.c_str(), |
|
|
|
switch_n->GetName().c_str(), switch_n->GetType().c_str()); |
|
|
|
GELOGE(FAILED, "Get attr ATTR_DYNAMIC_TYPE of node: %s failed.", switch_n->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
if (dynamic_type == static_cast<int32_t>(FIXED)) { |
|
|
|
REPORT_INNER_ERROR("E19999", "Attr:%s in op:%s(%s), value:%d check invalid", ATTR_DYNAMIC_TYPE.c_str(), |
|
|
|
switch_n->GetName().c_str(), switch_n->GetType().c_str(), dynamic_type); |
|
|
|
GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE shouldn't be 0."); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
if (dynamic_type_ != static_cast<int32_t>(FIXED) && dynamic_type_ != dynamic_type) { |
|
|
|
REPORT_INNER_ERROR("E19999", "Attr:%s in op:%s(%s), value:%d not same as attr value:%d in node before, " |
|
|
|
"check invalid", |
|
|
|
ATTR_DYNAMIC_TYPE.c_str(), switch_n->GetName().c_str(), switch_n->GetType().c_str(), |
|
|
|
dynamic_type, dynamic_type_); |
|
|
|
GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE of all switch_n node should be same, while one is %d and another is %d.", |
|
|
|
dynamic_type, dynamic_type_); |
|
|
|
return FAILED; |
|
|
|
@@ -194,6 +212,7 @@ Status MultiBatchPass::GetDynamicType() { |
|
|
|
dynamic_type_ = dynamic_type; |
|
|
|
} |
|
|
|
if (dynamic_type_ == static_cast<int32_t>(FIXED)) { |
|
|
|
REPORT_INNER_ERROR("E19999", "Find Attr:%s in all switcnn node failed", ATTR_DYNAMIC_TYPE.c_str()); |
|
|
|
GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE shouldn't be 0."); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
@@ -211,6 +230,8 @@ Status MultiBatchPass::GetUserDesignateShape() { |
|
|
|
for (const auto &switch_n : switch_n_nodes_) { |
|
|
|
std::vector<std::string> cur_data_name_order; |
|
|
|
if (!AttrUtils::GetListStr(switch_n->GetOpDesc(), ATTR_USER_DESIGNEATE_SHAPE_ORDER, cur_data_name_order)) { |
|
|
|
REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", ATTR_USER_DESIGNEATE_SHAPE_ORDER.c_str(), |
|
|
|
switch_n->GetName().c_str(), switch_n->GetType().c_str()); |
|
|
|
GELOGE(FAILED, "Get attr ATTR_USER_DESIGNEATE_SHAPE_ORDER of node: %s failed.", switch_n->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
@@ -219,6 +240,11 @@ Status MultiBatchPass::GetUserDesignateShape() { |
|
|
|
first_check = false; |
|
|
|
} else { |
|
|
|
if (data_name_order_ != cur_data_name_order) { |
|
|
|
REPORT_INNER_ERROR("E19999", "Attr:%s in op:%s(%s), value:%s not same as attr value:%s in node before, " |
|
|
|
"check invalid", ATTR_USER_DESIGNEATE_SHAPE_ORDER.c_str(), |
|
|
|
switch_n->GetName().c_str(), switch_n->GetType().c_str(), |
|
|
|
formats::JoinToString(cur_data_name_order).c_str(), |
|
|
|
formats::JoinToString(data_name_order_).c_str()); |
|
|
|
GELOGE(FAILED, "The ATTR_USER_DESIGNEATE_SHAPE_ORDER of switchN must be same: %s failed.", |
|
|
|
switch_n->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
@@ -226,6 +252,7 @@ Status MultiBatchPass::GetUserDesignateShape() { |
|
|
|
} |
|
|
|
} |
|
|
|
if (data_name_order_.empty()) { |
|
|
|
REPORT_INNER_ERROR("E19999", "Find Attr:%s in all switcnn node failed", ATTR_USER_DESIGNEATE_SHAPE_ORDER.c_str()); |
|
|
|
GELOGE(FAILED, "user shape order can not be empty"); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
@@ -248,6 +275,8 @@ bool MultiBatchPass::CheckSwitchN(std::vector<std::vector<int64_t>> &batch_shape |
|
|
|
if (batch_num == 0) { |
|
|
|
batch_num = tmp_num; |
|
|
|
} else if (batch_num != tmp_num) { |
|
|
|
REPORT_INNER_ERROR("E19999", "Ouput size num:%u of node:%s(%s) not same as output size num:%d of node before, " |
|
|
|
"check invalid", tmp_num, node->GetName().c_str(), node->GetType().c_str(), batch_num); |
|
|
|
GELOGE(FAILED, "Output size of SwitchN not equal;"); |
|
|
|
return false; |
|
|
|
} |
|
|
|
@@ -259,10 +288,12 @@ bool MultiBatchPass::CheckSwitchN(std::vector<std::vector<int64_t>> &batch_shape |
|
|
|
} |
|
|
|
|
|
|
|
if (batch_shape.empty()) { |
|
|
|
REPORT_INNER_ERROR("E19999", "batch_shape size is empty after GetBatchInfo, check invalid"); |
|
|
|
GELOGE(FAILED, "batch_shape is empty."); |
|
|
|
return false; |
|
|
|
} |
|
|
|
if (combined_batch.empty()) { |
|
|
|
REPORT_INNER_ERROR("E19999", "combined_batch size is empty after GetBatchInfo, check invalid"); |
|
|
|
GELOGE(FAILED, "combined_batch is empty."); |
|
|
|
return false; |
|
|
|
} |
|
|
|
@@ -271,11 +302,15 @@ bool MultiBatchPass::CheckSwitchN(std::vector<std::vector<int64_t>> &batch_shape |
|
|
|
for (uint32_t i = 1; i < batch_num; i++) { |
|
|
|
size_t tmp_dim_num = batch_shape[i].size(); |
|
|
|
if (dim_num != tmp_dim_num) { |
|
|
|
REPORT_INNER_ERROR("E19999", "Dim num of batch_shape not equal, batch_0:%zu, batch_%u:%zu, check invalid", |
|
|
|
dim_num, i, tmp_dim_num); |
|
|
|
GELOGE(FAILED, "Dim num of batch_shape not equal, batch_0:%zu, batch_%u:%zu.", dim_num, i, tmp_dim_num); |
|
|
|
return false; |
|
|
|
} |
|
|
|
size_t tmp_combined_dim_num = combined_batch[i].size(); |
|
|
|
if (combined_dim_num != tmp_combined_dim_num) { |
|
|
|
REPORT_INNER_ERROR("E19999", "Dim num of combined_batch not equal, batch_0:%zu, batch_%u:%zu, check invalid", |
|
|
|
combined_dim_num, i, tmp_combined_dim_num); |
|
|
|
GELOGE(FAILED, "Dim num of combined_batch not equal, batch_0:%zu, batch_%u:%zu.", |
|
|
|
combined_dim_num, i, tmp_combined_dim_num); |
|
|
|
return false; |
|
|
|
@@ -303,23 +338,32 @@ bool MultiBatchPass::GetBatchInfo(uint32_t batch_num, std::vector<std::vector<in |
|
|
|
for (const NodePtr &node : switch_n_nodes_) { |
|
|
|
OpDescPtr op_desc = node->GetOpDesc(); |
|
|
|
if (op_desc == nullptr) { |
|
|
|
REPORT_INNER_ERROR("E19999", "OpDesc in node is nullptr, check invalid"); |
|
|
|
GELOGE(FAILED, "CheckDims failed, get op_desc failed, node: %s.", node->GetName().c_str()); |
|
|
|
return false; |
|
|
|
} |
|
|
|
std::vector<int64_t> output_dims; |
|
|
|
if (!AttrUtils::GetListInt(op_desc->GetOutputDesc(i), ATTR_NAME_SWITCHN_PRED_VALUE, output_dims)) { |
|
|
|
REPORT_CALL_ERROR("E19999", "Get Attr:%s from output:%u tensor of op:%s(%s) failed", |
|
|
|
ATTR_NAME_SWITCHN_PRED_VALUE.c_str(), i, |
|
|
|
op_desc->GetName().c_str(), op_desc->GetType().c_str()); |
|
|
|
GELOGE(FAILED, "CheckDims failed, get attr ATTR_NAME_SWITCHN_PRED_VALUE failed, batch_index=%u.", i); |
|
|
|
return false; |
|
|
|
} |
|
|
|
idx_batch_shape.emplace_back(output_dims); |
|
|
|
output_dims.clear(); |
|
|
|
if (!AttrUtils::GetListInt(op_desc->GetOutputDesc(i), ATTR_NAME_COMBINED_DYNAMIC_DIMS, output_dims)) { |
|
|
|
REPORT_CALL_ERROR("E19999", "Get Attr:%s from output:%u tensor of op:%s(%s) failed", |
|
|
|
ATTR_NAME_COMBINED_DYNAMIC_DIMS.c_str(), i, |
|
|
|
op_desc->GetName().c_str(), op_desc->GetType().c_str()); |
|
|
|
GELOGE(FAILED, "CheckDims failed, get attr ATTR_NAME_COMBINED_DYNAMIC_DIMS failed, batch_index=%u.", i); |
|
|
|
return false; |
|
|
|
} |
|
|
|
idx_combined_batch.emplace_back(output_dims); |
|
|
|
} |
|
|
|
if (!CheckDims(idx_batch_shape)) { |
|
|
|
REPORT_INNER_ERROR("E19999", "Attr:%s of all output:%u tensor in switcnn node not equal, or not exist, " |
|
|
|
"check invalid", ATTR_NAME_SWITCHN_PRED_VALUE.c_str(), i); |
|
|
|
GELOGE(FAILED, "CheckDims failed, batch_index=%u.", i); |
|
|
|
return false; |
|
|
|
} |
|
|
|
@@ -351,6 +395,9 @@ Status MultiBatchPass::FindSwitchOutNodes(uint32_t batch_num) { |
|
|
|
} |
|
|
|
bypass_nodes_.emplace_back(out_node); |
|
|
|
if (GraphUtils::RemoveEdge(out_data_anchor, peer_in_anchor) != GRAPH_SUCCESS) { |
|
|
|
REPORT_CALL_ERROR("E19999", "Remove edge between op:%s(%s)(index:%d) and op:%s(%s)(index:%d) failed", |
|
|
|
node->GetName().c_str(), node->GetType().c_str(), i, |
|
|
|
out_node->GetName().c_str(), out_node->GetType().c_str(), peer_in_anchor->GetIdx()); |
|
|
|
GELOGE(FAILED, "Remove SwitchN out_data_edge failed, %s->%s.", node->GetName().c_str(), |
|
|
|
out_node->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
@@ -359,6 +406,9 @@ Status MultiBatchPass::FindSwitchOutNodes(uint32_t batch_num) { |
|
|
|
output_nodes.emplace_back(identity_out_node); |
|
|
|
if (GraphUtils::RemoveEdge(out_node->GetOutControlAnchor(), identity_out_node->GetInControlAnchor()) != |
|
|
|
GRAPH_SUCCESS) { |
|
|
|
REPORT_CALL_ERROR("E19999", "Remove control edge between op:%s(%s) and op:%s(%s) failed", |
|
|
|
out_node->GetName().c_str(), out_node->GetType().c_str(), |
|
|
|
identity_out_node->GetName().c_str(), identity_out_node->GetType().c_str()); |
|
|
|
GELOGE(FAILED, "Remove SwitchN out_data_edge failed, %s->%s.", node->GetName().c_str(), |
|
|
|
out_node->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
@@ -401,6 +451,9 @@ Status MultiBatchPass::ReplaceSwitchN(const ComputeGraphPtr &graph, const OutDat |
|
|
|
|
|
|
|
// Add switchCase input edge |
|
|
|
if (GraphUtils::AddEdge(pred_value, switch_case->GetInDataAnchor(0)) != GRAPH_SUCCESS) { |
|
|
|
REPORT_CALL_ERROR("E19999", "Remove edge between op:%s(%s)(index:%d) and op:%s(%s)(index:0) failed", |
|
|
|
pred_value_node->GetName().c_str(), pred_value_node->GetType().c_str(), pred_value->GetIdx(), |
|
|
|
switch_case->GetName().c_str(), switch_case->GetType().c_str()); |
|
|
|
GELOGE(FAILED, "Add SwitchCase in_data_edge failed, %s->%s.", pred_value_node->GetName().c_str(), |
|
|
|
switch_case->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
@@ -448,6 +501,7 @@ NodePtr MultiBatchPass::CreateSwitchCaseNode(const ComputeGraphPtr &graph, const |
|
|
|
const std::vector<std::vector<int64_t>> &combined_batch) { |
|
|
|
OpDescPtr op_desc = MakeShared<OpDesc>(name, STREAMSWITCHN); |
|
|
|
if (op_desc == nullptr) { |
|
|
|
REPORT_CALL_ERROR("E19999", "New OpDesc failed"); |
|
|
|
GELOGE(FAILED, "Create op_desc failed, StreamSwitchN:%s.", name.c_str()); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
@@ -455,41 +509,56 @@ NodePtr MultiBatchPass::CreateSwitchCaseNode(const ComputeGraphPtr &graph, const |
|
|
|
GELOGI("Create StreamSwitchN op:%s.", name.c_str()); |
|
|
|
OpDescPtr pred_desc = pred_value->GetOwnerNode()->GetOpDesc(); |
|
|
|
if (pred_desc == nullptr) { |
|
|
|
REPORT_INNER_ERROR("E19999", "OpDesc in node is nullptr, check invalid"); |
|
|
|
GELOGE(FAILED, "Get pred_desc failed, StreamSwitchN:%s.", name.c_str()); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
if (op_desc->AddInputDesc(pred_desc->GetOutputDesc(pred_value->GetIdx())) != GRAPH_SUCCESS) { |
|
|
|
REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed", |
|
|
|
op_desc->GetName().c_str(), op_desc->GetType().c_str()); |
|
|
|
GELOGE(FAILED, "AddInputDesc failed, StreamSwitchN:%s.", name.c_str()); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
NodePtr switch_case_node = graph->AddNode(op_desc); |
|
|
|
if (switch_case_node == nullptr) { |
|
|
|
REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", |
|
|
|
op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str()); |
|
|
|
GELOGE(FAILED, "Create node failed, StreamSwitchN:%s.", name.c_str()); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
uint32_t batch_num = static_cast<uint32_t>(batch_shape.size()); |
|
|
|
if (!AttrUtils::SetInt(op_desc, ATTR_NAME_BATCH_NUM, batch_num)) { |
|
|
|
REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NAME_BATCH_NUM.c_str(), |
|
|
|
op_desc->GetName().c_str(), op_desc->GetType().c_str()); |
|
|
|
GELOGE(FAILED, "set attr ATTR_NAME_BATCH_NUM failed, StreamSwitchN:%s.", name.c_str()); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
if (!AttrUtils::SetInt(op_desc, ATTR_DYNAMIC_TYPE, dynamic_type_)) { |
|
|
|
REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_DYNAMIC_TYPE.c_str(), |
|
|
|
op_desc->GetName().c_str(), op_desc->GetType().c_str()); |
|
|
|
GELOGE(FAILED, "Set attr ATTR_DYNAMIC_TYPE failed, StreamSwitchN:%s.", name.c_str()); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
if (!AttrUtils::SetListStr(op_desc, ATTR_USER_DESIGNEATE_SHAPE_ORDER, data_name_order_)) { |
|
|
|
REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_USER_DESIGNEATE_SHAPE_ORDER.c_str(), |
|
|
|
op_desc->GetName().c_str(), op_desc->GetType().c_str()); |
|
|
|
GELOGE(FAILED, "Set attr ATTR_USER_DESIGNEATE_SHAPE_ORDER failed, StreamSwitchN:%s.", name.c_str()); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
for (uint32_t i = 0; i < batch_num; i++) { |
|
|
|
const std::string &attr_name = ATTR_NAME_PRED_VALUE + "_" + std::to_string(i); |
|
|
|
if (!AttrUtils::SetListInt(op_desc, attr_name, batch_shape[i])) { |
|
|
|
REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", attr_name.c_str(), |
|
|
|
op_desc->GetName().c_str(), op_desc->GetType().c_str()); |
|
|
|
GELOGE(FAILED, "set attr ATTR_NAME_PRED_VALUE failed, StreamSwitchN:%s.", name.c_str()); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
const std::string &attr_combined_batch = ATTR_NAME_COMBINED_BATCH + "_" + std::to_string(i); |
|
|
|
if (!AttrUtils::SetListInt(op_desc, attr_combined_batch, combined_batch[i])) { |
|
|
|
REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", attr_combined_batch.c_str(), |
|
|
|
op_desc->GetName().c_str(), op_desc->GetType().c_str()); |
|
|
|
GELOGE(FAILED, "set attr ATTR_NAME_COMBINED_BATCH failed, StreamSwitchN:%s.", name.c_str()); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
@@ -507,11 +576,15 @@ NodePtr MultiBatchPass::CreateSwitchCaseNode(const ComputeGraphPtr &graph, const |
|
|
|
Status MultiBatchPass::BypassSwitchN(const NodePtr &switch_n_node, const NodePtr &switch_case) { |
|
|
|
InDataAnchorPtr in_data_anchor = switch_n_node->GetInDataAnchor(SWITCH_DATA_INPUT); |
|
|
|
if (in_data_anchor == nullptr) { |
|
|
|
REPORT_INNER_ERROR("E19999", "Index:%u in data anchor of node:%s(%s) is nullptr, check invalid", |
|
|
|
SWITCH_DATA_INPUT, switch_n_node->GetName().c_str(), switch_n_node->GetType().c_str()); |
|
|
|
GELOGE(FAILED, "Check in_data_anchor failed, SwitchN:%s.", switch_n_node->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
OutDataAnchorPtr peer_data_anchor = in_data_anchor->GetPeerOutAnchor(); |
|
|
|
if (peer_data_anchor == nullptr) { |
|
|
|
REPORT_INNER_ERROR("E19999", "Index:%u in data anchor of node:%s(%s), its peer ahcnhor is nullptr, check invalid", |
|
|
|
SWITCH_DATA_INPUT, switch_n_node->GetName().c_str(), switch_n_node->GetType().c_str()); |
|
|
|
GELOGE(FAILED, "Check peer_data_anchor failed, SwitchN:%s.", switch_n_node->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
@@ -519,11 +592,17 @@ Status MultiBatchPass::BypassSwitchN(const NodePtr &switch_n_node, const NodePtr |
|
|
|
|
|
|
|
// Remove SwitchN data input |
|
|
|
if (GraphUtils::RemoveEdge(peer_data_anchor, in_data_anchor) != GRAPH_SUCCESS) { |
|
|
|
REPORT_CALL_ERROR("E19999", "Remove edge between op:%s(%s)(index:%d) and op:%s(%s)(index:%u) failed", |
|
|
|
data_input->GetName().c_str(), data_input->GetType().c_str(), peer_data_anchor->GetIdx(), |
|
|
|
switch_n_node->GetName().c_str(), switch_n_node->GetType().c_str(), SWITCH_DATA_INPUT); |
|
|
|
GELOGE(FAILED, "Remove SwitchN in_data_edge failed, %s->%s.", data_input->GetName().c_str(), |
|
|
|
switch_n_node->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
if (GraphUtils::AddEdge(data_input->GetOutControlAnchor(), switch_case->GetInControlAnchor()) != GRAPH_SUCCESS) { |
|
|
|
REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed", |
|
|
|
data_input->GetName().c_str(), data_input->GetType().c_str(), |
|
|
|
switch_case->GetName().c_str(), switch_case->GetType().c_str()); |
|
|
|
GELOGE(FAILED, "Add StreamSwitchN in_control_edge failed, %s->%s.", data_input->GetName().c_str(), |
|
|
|
switch_case->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
@@ -535,11 +614,20 @@ Status MultiBatchPass::BypassSwitchN(const NodePtr &switch_n_node, const NodePtr |
|
|
|
NodePtr data_output = peer_in_anchor->GetOwnerNode(); |
|
|
|
if ((GraphUtils::RemoveEdge(out_data_anchor, peer_in_anchor) != GRAPH_SUCCESS) || |
|
|
|
(GraphUtils::AddEdge(peer_data_anchor, peer_in_anchor) != GRAPH_SUCCESS)) { |
|
|
|
REPORT_CALL_ERROR("E19999", "Remove edge between op:%s(%s)(index:%d) and op:%s(%s)(index:%d) or " |
|
|
|
"Add edge between op:%s(%s)(index:%d) and op:%s(%s)(index:%d) failed", |
|
|
|
switch_n_node->GetName().c_str(), switch_n_node->GetType().c_str(), out_data_anchor->GetIdx(), |
|
|
|
data_output->GetName().c_str(), data_output->GetType().c_str(), peer_in_anchor->GetIdx(), |
|
|
|
data_input->GetName().c_str(), data_input->GetType().c_str(), peer_data_anchor->GetIdx(), |
|
|
|
data_output->GetName().c_str(), data_output->GetType().c_str(), peer_in_anchor->GetIdx()); |
|
|
|
GELOGE(FAILED, "Bypass SwitchN data_edge failed, %s->%s->%s.", data_input->GetName().c_str(), |
|
|
|
switch_n_node->GetName().c_str(), data_output->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
if (GraphUtils::AddEdge(switch_case->GetOutControlAnchor(), data_output->GetInControlAnchor()) != GRAPH_SUCCESS) { |
|
|
|
REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed", |
|
|
|
switch_case->GetName().c_str(), switch_case->GetType().c_str(), |
|
|
|
data_output->GetName().c_str(), data_output->GetType().c_str()); |
|
|
|
GELOGE(FAILED, "Add SwitchCase out_control_edge failed, %s->%s.", switch_case->GetName().c_str(), |
|
|
|
data_output->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
@@ -602,10 +690,15 @@ Status MultiBatchPass::AttachBatchLabel(uint32_t batch_idx) { |
|
|
|
if (cur_desc->HasAttr(ATTR_NAME_BATCH_LABEL)) { |
|
|
|
std::string tmp_label; |
|
|
|
if (!AttrUtils::GetStr(cur_desc, ATTR_NAME_BATCH_LABEL, tmp_label)) { |
|
|
|
REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", ATTR_NAME_BATCH_LABEL.c_str(), |
|
|
|
cur_desc->GetName().c_str(), cur_desc->GetType().c_str()); |
|
|
|
GELOGE(FAILED, "get attr ATTR_NAME_BATCH_LABEL failed, node: %s.", cur_desc->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
if (tmp_label != batch_label) { |
|
|
|
REPORT_INNER_ERROR("E19999", "Attr:%s from op:%s(%s) value:%s not equal to expect:%s, check invalid", |
|
|
|
ATTR_NAME_BATCH_LABEL.c_str(), cur_desc->GetName().c_str(), cur_desc->GetType().c_str(), |
|
|
|
tmp_label.c_str(), batch_label.c_str()); |
|
|
|
GELOGE(FAILED, "Reach other batch_branch, node:%s, cur_label:%s, batch_label:%s.", cur_desc->GetName().c_str(), |
|
|
|
tmp_label.c_str(), batch_label.c_str()); |
|
|
|
return FAILED; |
|
|
|
@@ -613,6 +706,8 @@ Status MultiBatchPass::AttachBatchLabel(uint32_t batch_idx) { |
|
|
|
} |
|
|
|
GELOGD("Attach batch_label %s to node %s.", batch_label.c_str(), cur_desc->GetName().c_str()); |
|
|
|
if (!AttrUtils::SetStr(cur_desc, ATTR_NAME_BATCH_LABEL, batch_label)) { |
|
|
|
REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NAME_BATCH_LABEL.c_str(), |
|
|
|
cur_desc->GetName().c_str(), cur_desc->GetType().c_str()); |
|
|
|
GELOGE(FAILED, "set attr ATTR_NAME_BATCH_LABEL failed, node:%s.", cur_desc->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
@@ -625,6 +720,8 @@ Status MultiBatchPass::AttachBatchLabel(uint32_t batch_idx) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (type == NETOUTPUT) { |
|
|
|
REPORT_CALL_ERROR("E19999", "SReach net_output without Merge, cur_node:%s(%s), check invalid", |
|
|
|
cur_node->GetName().c_str(), cur_node->GetType().c_str()); |
|
|
|
GELOGE(FAILED, "Reach net_output without Merge, cur_node:%s.", cur_node->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
@@ -661,6 +758,8 @@ Status MultiBatchPass::AttachStreamLabel(uint32_t batch_idx, const std::string & |
|
|
|
|
|
|
|
GELOGD("Attach stream_label %s to node %s.", stream_label.c_str(), cur_desc->GetName().c_str()); |
|
|
|
if (SetStreamLabel(cur_node, stream_label) != SUCCESS) { |
|
|
|
REPORT_CALL_ERROR("E19999", "Set stream_label:%s to op:%s(%s) failed", |
|
|
|
stream_label.c_str(), cur_node->GetName().c_str(), cur_node->GetType().c_str()); |
|
|
|
GELOGE(FAILED, "Set stream_label failed, node:%s.", cur_node->GetName().c_str()); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
@@ -686,16 +785,17 @@ Status MultiBatchPass::MoveCtrlEdges(const NodePtr &old_node, const NodePtr &new |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
for (const NodePtr &in_ctrl_node : old_node->GetInControlNodes()) { |
|
|
|
GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctrl_node->GetOutControlAnchor(), old_node->GetInControlAnchor()), |
|
|
|
GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(in_ctrl_node->GetOutControlAnchor(), old_node->GetInControlAnchor()), |
|
|
|
"Merge remove in ctrl edge failed."); |
|
|
|
GE_CHK_STATUS(GraphUtils::AddEdge(in_ctrl_node->GetOutControlAnchor(), new_node->GetInControlAnchor()), |
|
|
|
GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(in_ctrl_node->GetOutControlAnchor(), new_node->GetInControlAnchor()), |
|
|
|
"StreamMerge add in ctrl edge failed."); |
|
|
|
} |
|
|
|
|
|
|
|
for (const NodePtr &out_ctrl_node : old_node->GetOutControlNodes()) { |
|
|
|
GE_CHK_STATUS(GraphUtils::RemoveEdge(old_node->GetOutControlAnchor(), out_ctrl_node->GetInControlAnchor()), |
|
|
|
"Merge remove out ctrl edge failed."); |
|
|
|
GE_CHK_STATUS(GraphUtils::AddEdge(new_node->GetOutControlAnchor(), out_ctrl_node->GetInControlAnchor()), |
|
|
|
GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(old_node->GetOutControlAnchor(), |
|
|
|
out_ctrl_node->GetInControlAnchor()), |
|
|
|
"Merge remove out ctrl edge failed."); |
|
|
|
GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(new_node->GetOutControlAnchor(), out_ctrl_node->GetInControlAnchor()), |
|
|
|
"StreamMerge add out ctrl edge failed."); |
|
|
|
} |
|
|
|
return SUCCESS; |
|
|
|
|