|
|
|
@@ -539,10 +539,7 @@ STATUS UnifyFormatPass::HandleGraphNode(const FuncGraphPtr &func_graph, const CN |
|
|
|
if (!need_reset_ && TransTransFusion(func_graph, cnode)) { |
|
|
|
return lite::RET_OK; |
|
|
|
} |
|
|
|
std::unordered_map<AnfNodePtr, AnfNodePtr> match; |
|
|
|
PreProcessFowardInsert(func_graph, cnode, &match); |
|
|
|
auto status = node_infer_shape_.InferShape(cnode); |
|
|
|
PostProcessFowardInsert(func_graph, cnode, match); |
|
|
|
if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) { |
|
|
|
MS_LOG(ERROR) << "infer shape failed: " << cnode->fullname_with_scope(); |
|
|
|
return lite::RET_ERROR; |
|
|
|
@@ -551,8 +548,6 @@ STATUS UnifyFormatPass::HandleGraphNode(const FuncGraphPtr &func_graph, const CN |
|
|
|
} |
|
|
|
auto before_perm = trans_info.pre_ == kNHWC2NCHW ? NH2NC : NC2NH; |
|
|
|
auto after_perm = trans_info.post_ == kNCHW2NHWC ? NC2NH : NH2NC; |
|
|
|
std::unordered_map<AnfNodePtr, AnfNodePtr> match; |
|
|
|
PreProcessFowardInsert(func_graph, cnode, &match); |
|
|
|
if (InsertPreTransNode(func_graph, cnode, before_perm) != lite::RET_OK) { |
|
|
|
MS_LOG(ERROR) << "insert pre node failed." << cnode->fullname_with_scope(); |
|
|
|
return lite::RET_ERROR; |
|
|
|
@@ -562,7 +557,6 @@ STATUS UnifyFormatPass::HandleGraphNode(const FuncGraphPtr &func_graph, const CN |
|
|
|
MS_LOG(ERROR) << "infer shape failed."; |
|
|
|
return lite::RET_ERROR; |
|
|
|
} |
|
|
|
PostProcessFowardInsert(func_graph, cnode, match); |
|
|
|
if (InsertPostTransNode(func_graph, cnode, after_perm) != lite::RET_OK) { |
|
|
|
MS_LOG(ERROR) << "insert post node failed." << cnode->fullname_with_scope(); |
|
|
|
return lite::RET_ERROR; |
|
|
|
@@ -629,57 +623,9 @@ STATUS UnifyFormatPass::HandleGraphMultiNode(const FuncGraphPtr &func_graph, con |
|
|
|
return lite::RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
void UnifyFormatPass::PreProcessFowardInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode, |
|
|
|
std::unordered_map<AnfNodePtr, AnfNodePtr> *match) { |
|
|
|
MS_ASSERT(func_graph != nullptr && cnode != nullptr); |
|
|
|
auto graph_name = GetValue<std::string>(func_graph->get_attr("graph_name")); |
|
|
|
if (sub_inputs_map_.find(graph_name) == sub_inputs_map_.end()) { |
|
|
|
return; |
|
|
|
} |
|
|
|
auto manager = func_graph->manager(); |
|
|
|
MS_ASSERT(manager != nullptr); |
|
|
|
auto tr = manager->Transact(); |
|
|
|
for (size_t i = 1; i < cnode->size(); ++i) { |
|
|
|
if (sub_inputs_map_[graph_name].find(cnode->input(i)) == sub_inputs_map_[graph_name].end()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
match->insert(std::make_pair(sub_inputs_map_[graph_name][cnode->input(i)], cnode->input(i))); |
|
|
|
tr.SetEdge(cnode, i, sub_inputs_map_[graph_name][cnode->input(i)]); |
|
|
|
tr.Commit(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void UnifyFormatPass::PostProcessFowardInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode, |
|
|
|
const std::unordered_map<AnfNodePtr, AnfNodePtr> &match) { |
|
|
|
MS_ASSERT(func_graph != nullptr && cnode != nullptr); |
|
|
|
if (match.empty()) { |
|
|
|
return; |
|
|
|
} |
|
|
|
auto manager = func_graph->manager(); |
|
|
|
MS_ASSERT(manager != nullptr); |
|
|
|
auto tr = manager->Transact(); |
|
|
|
for (size_t i = 1; i < cnode->size(); ++i) { |
|
|
|
if (match.find(cnode->input(i)) != match.end()) { |
|
|
|
tr.SetEdge(cnode, i, match.at(cnode->input(i))); |
|
|
|
tr.Commit(); |
|
|
|
} |
|
|
|
if (CheckPrimitiveType(cnode->input(i), prim::kPrimTranspose)) { |
|
|
|
auto trans_cnode = cnode->input(i)->cast<CNodePtr>(); |
|
|
|
for (size_t j = 1; j < trans_cnode->size(); ++j) { |
|
|
|
if (match.find(trans_cnode->input(j)) == match.end()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
tr.SetEdge(trans_cnode, j, match.at(trans_cnode->input(j))); |
|
|
|
tr.Commit(); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void UnifyFormatPass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) { |
|
|
|
MS_ASSERT(cnode != nullptr && sub_graph != nullptr); |
|
|
|
auto subgraph_name = GetValue<std::string>(sub_graph->get_attr("graph_name")); |
|
|
|
sub_inputs_map_[subgraph_name] = {}; |
|
|
|
sub_inputs_map_[sub_graph] = {}; |
|
|
|
auto sub_inputs = sub_graph->get_inputs(); |
|
|
|
for (auto &node : sub_inputs) { |
|
|
|
auto param_node = node->cast<ParameterPtr>(); |
|
|
|
@@ -689,19 +635,52 @@ void UnifyFormatPass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr |
|
|
|
node_name = node_name.substr(0, last_underline); |
|
|
|
last_underline = node_name.find_last_of("_"); |
|
|
|
auto index = std::stoi(node_name.substr(last_underline + 1)) + 3; |
|
|
|
if (utils::isa<CNodePtr>(cnode->input(index)) && CheckPrimitiveType(cnode->input(index), prim::kPrimTranspose)) { |
|
|
|
std::vector<int> shape = {-1}; |
|
|
|
auto trans_cnode = cnode->input(index)->cast<CNodePtr>(); |
|
|
|
param_node->set_abstract(GetCNodeInputAbstract(cnode, index)->Clone()); |
|
|
|
if (utils::isa<CNodePtr>(cnode->input(index))) { |
|
|
|
ShapeVector shape_vec = {-1}; |
|
|
|
auto out_cnode = cnode->input(index)->cast<CNodePtr>(); |
|
|
|
MS_ASSERT(trans_cnode != nullptr); |
|
|
|
auto trans_prim = GetValueNode<PrimitivePtr>(trans_cnode->input(0)); |
|
|
|
if (trans_prim->GetAttr(kInferDone) != nullptr && GetValue<bool>(trans_prim->GetAttr(kInferDone))) { |
|
|
|
shape = node_infer_shape_.GetInputShape(cnode, index); |
|
|
|
auto out_prim = GetValueNode<PrimitivePtr>(out_cnode->input(0)); |
|
|
|
if (out_prim->GetAttr(kInferDone) == nullptr || !GetValue<bool>(out_prim->GetAttr(kInferDone))) { |
|
|
|
param_node->abstract()->set_shape(std::make_shared<abstract::Shape>(shape_vec)); |
|
|
|
} |
|
|
|
auto type = trans_cnode->abstract()->cast<abstract::AbstractTensorPtr>()->element()->GetTypeTrack(); |
|
|
|
std::vector<int64_t> shape_vec(shape.begin(), shape.end()); |
|
|
|
param_node->set_abstract(std::make_shared<abstract::AbstractTensor>(type, shape_vec)); |
|
|
|
} else { |
|
|
|
sub_inputs_map_[subgraph_name][node] = cnode->input(index); |
|
|
|
lite::DataInfo data_info; |
|
|
|
if (utils::isa<ParameterPtr>(cnode->input(index))) { |
|
|
|
if (cnode->input(index)->cast<ParameterPtr>()->has_default()) { |
|
|
|
param_node->set_default_param(cnode->input(index)->cast<ParameterPtr>()->default_param()); |
|
|
|
sub_inputs_map_[sub_graph].push_back(param_node); |
|
|
|
} |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto status = lite::FetchDataFromValueNode(cnode, index, fmk_type_, train_flag_, &data_info); |
|
|
|
if (status != lite::RET_OK) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
ShapeVector shape_vec(data_info.shape_.begin(), data_info.shape_.end()); |
|
|
|
if (data_info.data_.empty()) { |
|
|
|
param_node->set_default_param(std::make_shared<tensor::Tensor>((TypeId)data_info.data_type_, shape_vec)); |
|
|
|
} else { |
|
|
|
param_node->set_default_param(std::make_shared<tensor::Tensor>((TypeId)data_info.data_type_, shape_vec, |
|
|
|
data_info.data_.data(), data_info.data_.size())); |
|
|
|
} |
|
|
|
sub_inputs_map_[sub_graph].push_back(param_node); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void UnifyFormatPass::ResetSubGraphInput() { |
|
|
|
for (auto iter = sub_inputs_map_.begin(); iter != sub_inputs_map_.end(); ++iter) { |
|
|
|
auto &sub_graph = iter->first; |
|
|
|
auto &sub_inputs = iter->second; |
|
|
|
auto manager = sub_graph->manager(); |
|
|
|
MS_ASSERT(manager != nullptr); |
|
|
|
for (auto &sub_input : sub_inputs) { |
|
|
|
auto param_node = sub_graph->add_parameter(); |
|
|
|
MS_ASSERT(param_node != nullptr); |
|
|
|
param_node->set_abstract(sub_input->abstract()->Clone()); |
|
|
|
param_node->set_name(sub_input->name()); |
|
|
|
manager->Replace(sub_input, param_node); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -804,13 +783,6 @@ bool UnifyFormatPass::BasicProcess(const FuncGraphPtr &func_graph, bool main_gra |
|
|
|
} |
|
|
|
} |
|
|
|
if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) { |
|
|
|
auto origin_inputs = cnode->inputs(); |
|
|
|
for (size_t i = 3; i < cnode->size(); ++i) { |
|
|
|
if (sub_inputs_map_.find(graph_name) != sub_inputs_map_.end() && |
|
|
|
sub_inputs_map_[graph_name].find(cnode->input(i)) != sub_inputs_map_[graph_name].end()) { |
|
|
|
cnode->set_input(i, sub_inputs_map_[graph_name][cnode->input(i)]); |
|
|
|
} |
|
|
|
} |
|
|
|
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1)); |
|
|
|
if (sub_func_graph == nullptr) { |
|
|
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); |
|
|
|
@@ -828,7 +800,6 @@ bool UnifyFormatPass::BasicProcess(const FuncGraphPtr &func_graph, bool main_gra |
|
|
|
(void)BasicProcess(sub_func_graph, false); |
|
|
|
SetSubGraphOutput(cnode, sub_func_graph); |
|
|
|
SetSubGraphAbstract(cnode, sub_func_graph); |
|
|
|
cnode->set_inputs(origin_inputs); |
|
|
|
continue; |
|
|
|
} |
|
|
|
status = HandleGraphNode(func_graph, cnode); |
|
|
|
@@ -836,6 +807,7 @@ bool UnifyFormatPass::BasicProcess(const FuncGraphPtr &func_graph, bool main_gra |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
ResetSubGraphInput(); |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -858,13 +830,6 @@ bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_grap |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) { |
|
|
|
auto origin_inputs = cnode->inputs(); |
|
|
|
for (size_t i = 3; i < cnode->size(); ++i) { |
|
|
|
if (sub_inputs_map_.find(graph_name) != sub_inputs_map_.end() && |
|
|
|
sub_inputs_map_[graph_name].find(cnode->input(i)) != sub_inputs_map_[graph_name].end()) { |
|
|
|
cnode->set_input(i, sub_inputs_map_[graph_name][cnode->input(i)]); |
|
|
|
} |
|
|
|
} |
|
|
|
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1)); |
|
|
|
if (sub_func_graph == nullptr) { |
|
|
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); |
|
|
|
@@ -882,7 +847,6 @@ bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_grap |
|
|
|
(void)DecreaseTransposeForSingleOp(sub_func_graph); |
|
|
|
SetSubGraphOutput(cnode, sub_func_graph); |
|
|
|
SetSubGraphAbstract(cnode, sub_func_graph); |
|
|
|
cnode->set_inputs(origin_inputs); |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); |
|
|
|
@@ -904,6 +868,7 @@ bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_grap |
|
|
|
return false; |
|
|
|
} |
|
|
|
} |
|
|
|
ResetSubGraphInput(); |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
@@ -1010,8 +975,53 @@ bool UnifyFormatPass::ResetFuncGraph(const FuncGraphPtr &func_graph) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
bool UnifyFormatPass::JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph) { |
|
|
|
MS_ASSERT(func_graph != nullptr); |
|
|
|
auto node_list = TopoSort(func_graph->get_return()); |
|
|
|
bool all_op_can_infer = true; |
|
|
|
for (auto &node : node_list) { |
|
|
|
if (!utils::isa<CNodePtr>(node)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
if (IsSpecialType(cnode)) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) { |
|
|
|
auto sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1)); |
|
|
|
if (sub_func_graph == nullptr) { |
|
|
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); |
|
|
|
all_op_can_infer = false; |
|
|
|
} else { |
|
|
|
all_op_can_infer = all_op_can_infer && JudgeAllOpsCanInfer(sub_func_graph); |
|
|
|
} |
|
|
|
sub_func_graph = GetValueNode<FuncGraphPtr>(cnode->input(1)); |
|
|
|
if (sub_func_graph == nullptr) { |
|
|
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); |
|
|
|
all_op_can_infer = false; |
|
|
|
} else { |
|
|
|
all_op_can_infer = all_op_can_infer && JudgeAllOpsCanInfer(sub_func_graph); |
|
|
|
} |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto cur_op_can_infer = node_infer_shape_.JudgeOpSupportInfer(cnode); |
|
|
|
if (!cur_op_can_infer) { |
|
|
|
auto prim = GetValueNode<PrimitivePtr>(cnode->input(0)); |
|
|
|
MS_ASSERT(prim != nullptr); |
|
|
|
lite::NotSupportOp::GetInstance()->InsertOp(prim->name()); |
|
|
|
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NOT_SUPPORT); |
|
|
|
all_op_can_infer = false; |
|
|
|
} |
|
|
|
} |
|
|
|
return all_op_can_infer; |
|
|
|
} |
|
|
|
|
|
|
|
bool UnifyFormatPass::RunOnlyForShape(const FuncGraphPtr &func_graph) { |
|
|
|
MS_ASSERT(func_graph != nullptr); |
|
|
|
if (!JudgeAllOpsCanInfer(func_graph)) { |
|
|
|
MS_LOG(ERROR) << "exist op cannot support infer shape."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
need_reset_ = true; |
|
|
|
// insert transpose for some ops whose format must be NHWC, which is depend on framework. |
|
|
|
// In this process, transpose op cannot be fused to restore the original graph. |
|
|
|
@@ -1039,6 +1049,10 @@ bool UnifyFormatPass::Run(const FuncGraphPtr &func_graph) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
} |
|
|
|
if (!JudgeAllOpsCanInfer(func_graph)) { |
|
|
|
MS_LOG(ERROR) << "exist op cannot support infer shape."; |
|
|
|
return false; |
|
|
|
} |
|
|
|
// insert transpose for some ops whose format must be NHWC, which is depend on framework. |
|
|
|
// In this process, tranpose can be fused, which the original graph may not be able to restored. |
|
|
|
if (!BasicProcess(func_graph, true)) { |
|
|
|
|