diff --git a/mindspore/lite/tools/optimizer/graph/node_infershape.cc b/mindspore/lite/tools/optimizer/graph/node_infershape.cc index 658cccbdd4..041ab24761 100644 --- a/mindspore/lite/tools/optimizer/graph/node_infershape.cc +++ b/mindspore/lite/tools/optimizer/graph/node_infershape.cc @@ -109,6 +109,20 @@ tensor::TensorPtr NewTensorInfo(lite::Tensor *tensor) { } } // namespace +bool NodeInferShape::JudgeOpSupportInfer(const CNodePtr &cnode) { + MS_ASSERT(cnode != nullptr); + auto prim_t = lite::GetPrimitiveT(cnode->input(0)); + if (prim_t == nullptr) { + return false; + } + auto parameter_gen = lite::PopulateRegistry::GetInstance()->GetParameterCreator(prim_t->value.type, lite::SCHEMA_CUR); + if (parameter_gen == nullptr) { + delete prim_t; + return false; + } + return true; +} + STATUS NodeInferShape::InferShape(const CNodePtr &cnode) { MS_ASSERT(cnode != nullptr); auto anf_prim = GetValueNode>(cnode->input(0)); diff --git a/mindspore/lite/tools/optimizer/graph/node_infershape.h b/mindspore/lite/tools/optimizer/graph/node_infershape.h index 8879ac5b81..8c73db960b 100644 --- a/mindspore/lite/tools/optimizer/graph/node_infershape.h +++ b/mindspore/lite/tools/optimizer/graph/node_infershape.h @@ -38,6 +38,7 @@ class NodeInferShape { train_flag_ = train_flag; } STATUS InferShape(const CNodePtr &cnode); + bool JudgeOpSupportInfer(const CNodePtr &cnode); std::vector GetInputShape(const CNodePtr &cnode, size_t index); std::vector GetIntVecInput(const CNodePtr &cnode, size_t index); diff --git a/mindspore/lite/tools/optimizer/graph/unify_format_pass.cc b/mindspore/lite/tools/optimizer/graph/unify_format_pass.cc index dd81a9490e..5d1dfafe7c 100644 --- a/mindspore/lite/tools/optimizer/graph/unify_format_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/unify_format_pass.cc @@ -539,10 +539,7 @@ STATUS UnifyFormatPass::HandleGraphNode(const FuncGraphPtr &func_graph, const CN if (!need_reset_ && TransTransFusion(func_graph, cnode)) { return lite::RET_OK; } - std::unordered_map match; - PreProcessFowardInsert(func_graph, cnode, &match); auto status = node_infer_shape_.InferShape(cnode); - PostProcessFowardInsert(func_graph, cnode, match); if (status != lite::RET_OK && status != lite::RET_INFER_INVALID) { MS_LOG(ERROR) << "infer shape failed: " << cnode->fullname_with_scope(); return lite::RET_ERROR; @@ -551,8 +548,6 @@ STATUS UnifyFormatPass::HandleGraphNode(const FuncGraphPtr &func_graph, const CN } auto before_perm = trans_info.pre_ == kNHWC2NCHW ? NH2NC : NC2NH; auto after_perm = trans_info.post_ == kNCHW2NHWC ? NC2NH : NH2NC; - std::unordered_map match; - PreProcessFowardInsert(func_graph, cnode, &match); if (InsertPreTransNode(func_graph, cnode, before_perm) != lite::RET_OK) { MS_LOG(ERROR) << "insert pre node failed." << cnode->fullname_with_scope(); return lite::RET_ERROR; @@ -562,7 +557,6 @@ STATUS UnifyFormatPass::HandleGraphNode(const FuncGraphPtr &func_graph, const CN MS_LOG(ERROR) << "infer shape failed."; return lite::RET_ERROR; } - PostProcessFowardInsert(func_graph, cnode, match); if (InsertPostTransNode(func_graph, cnode, after_perm) != lite::RET_OK) { MS_LOG(ERROR) << "insert post node failed." << cnode->fullname_with_scope(); return lite::RET_ERROR; @@ -629,57 +623,9 @@ STATUS UnifyFormatPass::HandleGraphMultiNode(const FuncGraphPtr &func_graph, con return lite::RET_OK; } -void UnifyFormatPass::PreProcessFowardInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode, - std::unordered_map *match) { - MS_ASSERT(func_graph != nullptr && cnode != nullptr); - auto graph_name = GetValue(func_graph->get_attr("graph_name")); - if (sub_inputs_map_.find(graph_name) == sub_inputs_map_.end()) { - return; - } - auto manager = func_graph->manager(); - MS_ASSERT(manager != nullptr); - auto tr = manager->Transact(); - for (size_t i = 1; i < cnode->size(); ++i) { - if (sub_inputs_map_[graph_name].find(cnode->input(i)) == sub_inputs_map_[graph_name].end()) { - continue; - } - match->insert(std::make_pair(sub_inputs_map_[graph_name][cnode->input(i)], cnode->input(i))); - tr.SetEdge(cnode, i, sub_inputs_map_[graph_name][cnode->input(i)]); - tr.Commit(); - } -} - -void UnifyFormatPass::PostProcessFowardInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode, - const std::unordered_map &match) { - MS_ASSERT(func_graph != nullptr && cnode != nullptr); - if (match.empty()) { - return; - } - auto manager = func_graph->manager(); - MS_ASSERT(manager != nullptr); - auto tr = manager->Transact(); - for (size_t i = 1; i < cnode->size(); ++i) { - if (match.find(cnode->input(i)) != match.end()) { - tr.SetEdge(cnode, i, match.at(cnode->input(i))); - tr.Commit(); - } - if (CheckPrimitiveType(cnode->input(i), prim::kPrimTranspose)) { - auto trans_cnode = cnode->input(i)->cast(); - for (size_t j = 1; j < trans_cnode->size(); ++j) { - if (match.find(trans_cnode->input(j)) == match.end()) { - continue; - } - tr.SetEdge(trans_cnode, j, match.at(trans_cnode->input(j))); - tr.Commit(); - } - } - } -} - void UnifyFormatPass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph) { MS_ASSERT(cnode != nullptr && sub_graph != nullptr); - auto subgraph_name = GetValue(sub_graph->get_attr("graph_name")); - sub_inputs_map_[subgraph_name] = {}; + sub_inputs_map_[sub_graph] = {}; auto sub_inputs = sub_graph->get_inputs(); for (auto &node : sub_inputs) { auto param_node = node->cast(); @@ -689,19 +635,52 @@ void UnifyFormatPass::SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr node_name = node_name.substr(0, last_underline); last_underline = node_name.find_last_of("_"); auto index = std::stoi(node_name.substr(last_underline + 1)) + 3; - if (utils::isa(cnode->input(index)) && CheckPrimitiveType(cnode->input(index), prim::kPrimTranspose)) { - std::vector shape = {-1}; - auto trans_cnode = cnode->input(index)->cast(); + param_node->set_abstract(GetCNodeInputAbstract(cnode, index)->Clone()); + if (utils::isa(cnode->input(index))) { + ShapeVector shape_vec = {-1}; + auto out_cnode = cnode->input(index)->cast(); MS_ASSERT(trans_cnode != nullptr); - auto trans_prim = GetValueNode(trans_cnode->input(0)); - if (trans_prim->GetAttr(kInferDone) != nullptr && GetValue(trans_prim->GetAttr(kInferDone))) { - shape = node_infer_shape_.GetInputShape(cnode, index); + auto out_prim = GetValueNode(out_cnode->input(0)); + if (out_prim->GetAttr(kInferDone) == nullptr || !GetValue(out_prim->GetAttr(kInferDone))) { + param_node->abstract()->set_shape(std::make_shared(shape_vec)); } - auto type = trans_cnode->abstract()->cast()->element()->GetTypeTrack(); - std::vector shape_vec(shape.begin(), shape.end()); - param_node->set_abstract(std::make_shared(type, shape_vec)); } else { - sub_inputs_map_[subgraph_name][node] = cnode->input(index); + lite::DataInfo data_info; + if (utils::isa(cnode->input(index))) { + if (cnode->input(index)->cast()->has_default()) { + param_node->set_default_param(cnode->input(index)->cast()->default_param()); + sub_inputs_map_[sub_graph].push_back(param_node); + } + continue; + } + auto status = lite::FetchDataFromValueNode(cnode, index, fmk_type_, train_flag_, &data_info); + if (status != lite::RET_OK) { + continue; + } + ShapeVector shape_vec(data_info.shape_.begin(), data_info.shape_.end()); + if (data_info.data_.empty()) { + param_node->set_default_param(std::make_shared((TypeId)data_info.data_type_, shape_vec)); + } else { + param_node->set_default_param(std::make_shared((TypeId)data_info.data_type_, shape_vec, + data_info.data_.data(), data_info.data_.size())); + } + sub_inputs_map_[sub_graph].push_back(param_node); + } + } +} + +void UnifyFormatPass::ResetSubGraphInput() { + for (auto iter = sub_inputs_map_.begin(); iter != sub_inputs_map_.end(); ++iter) { + auto &sub_graph = iter->first; + auto &sub_inputs = iter->second; + auto manager = sub_graph->manager(); + MS_ASSERT(manager != nullptr); + for (auto &sub_input : sub_inputs) { + auto param_node = sub_graph->add_parameter(); + MS_ASSERT(param_node != nullptr); + param_node->set_abstract(sub_input->abstract()->Clone()); + param_node->set_name(sub_input->name()); + manager->Replace(sub_input, param_node); } } } @@ -804,13 +783,6 @@ bool UnifyFormatPass::BasicProcess(const FuncGraphPtr &func_graph, bool main_gra } } if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) { - auto origin_inputs = cnode->inputs(); - for (size_t i = 3; i < cnode->size(); ++i) { - if (sub_inputs_map_.find(graph_name) != sub_inputs_map_.end() && - sub_inputs_map_[graph_name].find(cnode->input(i)) != sub_inputs_map_[graph_name].end()) { - cnode->set_input(i, sub_inputs_map_[graph_name][cnode->input(i)]); - } - } auto sub_func_graph = GetValueNode(cnode->input(1)); if (sub_func_graph == nullptr) { lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); @@ -828,7 +800,6 @@ bool UnifyFormatPass::BasicProcess(const FuncGraphPtr &func_graph, bool main_gra (void)BasicProcess(sub_func_graph, false); SetSubGraphOutput(cnode, sub_func_graph); SetSubGraphAbstract(cnode, sub_func_graph); - cnode->set_inputs(origin_inputs); continue; } status = HandleGraphNode(func_graph, cnode); @@ -836,6 +807,7 @@ bool UnifyFormatPass::BasicProcess(const FuncGraphPtr &func_graph, bool main_gra return false; } } + ResetSubGraphInput(); return true; } @@ -858,13 +830,6 @@ bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_grap continue; } if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) { - auto origin_inputs = cnode->inputs(); - for (size_t i = 3; i < cnode->size(); ++i) { - if (sub_inputs_map_.find(graph_name) != sub_inputs_map_.end() && - sub_inputs_map_[graph_name].find(cnode->input(i)) != sub_inputs_map_[graph_name].end()) { - cnode->set_input(i, sub_inputs_map_[graph_name][cnode->input(i)]); - } - } auto sub_func_graph = GetValueNode(cnode->input(1)); if (sub_func_graph == nullptr) { lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); @@ -882,7 +847,6 @@ bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_grap (void)DecreaseTransposeForSingleOp(sub_func_graph); SetSubGraphOutput(cnode, sub_func_graph); SetSubGraphAbstract(cnode, sub_func_graph); - cnode->set_inputs(origin_inputs); continue; } auto prim = GetValueNode(cnode->input(0)); @@ -904,6 +868,7 @@ bool UnifyFormatPass::DecreaseTransposeForSingleOp(const FuncGraphPtr &func_grap return false; } } + ResetSubGraphInput(); return true; } @@ -1010,8 +975,53 @@ bool UnifyFormatPass::ResetFuncGraph(const FuncGraphPtr &func_graph) { return true; } +bool UnifyFormatPass::JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph) { + MS_ASSERT(func_graph != nullptr); + auto node_list = TopoSort(func_graph->get_return()); + bool all_op_can_infer = true; + for (auto &node : node_list) { + if (!utils::isa(node)) { + continue; + } + auto cnode = node->cast(); + if (IsSpecialType(cnode)) { + continue; + } + if (CheckPrimitiveType(node, prim::kPrimIf) || CheckPrimitiveType(node, prim::kPrimWhile)) { + auto sub_func_graph = GetValueNode(cnode->input(1)); + if (sub_func_graph == nullptr) { + lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); + all_op_can_infer = false; + } else { + all_op_can_infer = all_op_can_infer && JudgeAllOpsCanInfer(sub_func_graph); + } + sub_func_graph = GetValueNode(cnode->input(1)); + if (sub_func_graph == nullptr) { + lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR); + all_op_can_infer = false; + } else { + all_op_can_infer = all_op_can_infer && JudgeAllOpsCanInfer(sub_func_graph); + } + continue; + } + auto cur_op_can_infer = node_infer_shape_.JudgeOpSupportInfer(cnode); + if (!cur_op_can_infer) { + auto prim = GetValueNode(cnode->input(0)); + MS_ASSERT(prim != nullptr); + lite::NotSupportOp::GetInstance()->InsertOp(prim->name()); + lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NOT_SUPPORT); + all_op_can_infer = false; + } + } + return all_op_can_infer; +} + bool UnifyFormatPass::RunOnlyForShape(const FuncGraphPtr &func_graph) { MS_ASSERT(func_graph != nullptr); + if (!JudgeAllOpsCanInfer(func_graph)) { + MS_LOG(ERROR) << "exist op cannot support infer shape."; + return false; + } need_reset_ = true; // insert transpose for some ops whose format must be NHWC, which is depend on framework. // In this process, transpose op cannot be fused to restore the original graph. @@ -1039,6 +1049,10 @@ bool UnifyFormatPass::Run(const FuncGraphPtr &func_graph) { return true; } } + if (!JudgeAllOpsCanInfer(func_graph)) { + MS_LOG(ERROR) << "exist op cannot support infer shape."; + return false; + } // insert transpose for some ops whose format must be NHWC, which is depend on framework. // In this process, tranpose can be fused, which the original graph may not be able to restored. if (!BasicProcess(func_graph, true)) { diff --git a/mindspore/lite/tools/optimizer/graph/unify_format_pass.h b/mindspore/lite/tools/optimizer/graph/unify_format_pass.h index 450090ee00..81a69b04f6 100644 --- a/mindspore/lite/tools/optimizer/graph/unify_format_pass.h +++ b/mindspore/lite/tools/optimizer/graph/unify_format_pass.h @@ -45,6 +45,7 @@ class UnifyFormatPass : public Pass { bool RunOnlyForShape(const FuncGraphPtr &func_graph); private: + bool JudgeAllOpsCanInfer(const FuncGraphPtr &func_graph); bool ResetFuncGraph(const FuncGraphPtr &func_graph); bool BasicProcess(const FuncGraphPtr &func_graph, bool main_graph); bool DecreaseTransposeForSingleOp(const FuncGraphPtr &func_graph); @@ -61,11 +62,8 @@ class UnifyFormatPass : public Pass { STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector &perm); STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, TransTypePair *trans_insert_info); STATUS InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector &perm); - void PreProcessFowardInsert(const FuncGraphPtr &func_graph, const CNodePtr &cnode, - std::unordered_map *match); - void PostProcessFowardInsert(const FuncGraphPtr &funcgraph, const CNodePtr &cnode, - const std::unordered_map &match); void SetSubGraphInput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph); + void ResetSubGraphInput(); void SetSubGraphOutput(const CNodePtr &cnode, const FuncGraphPtr &sub_graph); void SetSubGraphAbstract(const CNodePtr &cnode, const FuncGraphPtr &sub_graph); FmkType fmk_type_{lite::converter::FmkType_MS}; @@ -75,7 +73,7 @@ class UnifyFormatPass : public Pass { TransposeStrategy transpose_strategy_; std::set pre_insert_trans_; std::set post_insert_trans_; - std::unordered_map> sub_inputs_map_; + std::unordered_map> sub_inputs_map_; }; } // namespace opt } // namespace mindspore