From: @lianliguang Reviewed-by: @zhoufeng54,@chujinjin Signed-off-by: @chujinjin tags/v1.1.0
| @@ -71,7 +71,6 @@ | |||||
| #include "backend/optimizer/ascend/format_type/insert_transpose_for_dyanmic_gru_v2.h" | #include "backend/optimizer/ascend/format_type/insert_transpose_for_dyanmic_gru_v2.h" | ||||
| #include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h" | #include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h" | ||||
| #include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h" | #include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h" | ||||
| #include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h" | |||||
| #include "backend/optimizer/ascend/format_type/convert_cast_format.h" | #include "backend/optimizer/ascend/format_type/convert_cast_format.h" | ||||
| #include "backend/optimizer/pass/getitem_tuple.h" | #include "backend/optimizer/pass/getitem_tuple.h" | ||||
| #include "backend/optimizer/pass/optimize_dependence.h" | #include "backend/optimizer/pass/optimize_dependence.h" | ||||
| @@ -250,7 +249,6 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap | |||||
| mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>()); | mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>()); | ||||
| mixed_precision_pm->AddPass(std::make_shared<LayerNormBetaGammaBackpropFusion>()); | mixed_precision_pm->AddPass(std::make_shared<LayerNormBetaGammaBackpropFusion>()); | ||||
| mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>()); | mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>()); | ||||
| mixed_precision_pm->AddPass(std::make_shared<SplitUnsupportedTransData>()); | |||||
| mixed_precision_pm->AddPass(std::make_shared<ConvertUnSupportNodeToAICPU>()); | mixed_precision_pm->AddPass(std::make_shared<ConvertUnSupportNodeToAICPU>()); | ||||
| mixed_precision_pm->AddPass(std::make_shared<RemoveInternalOutputCast>()); | mixed_precision_pm->AddPass(std::make_shared<RemoveInternalOutputCast>()); | ||||
| optimizer->AddPassManager(mixed_precision_pm); | optimizer->AddPassManager(mixed_precision_pm); | ||||
| @@ -256,9 +256,9 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, | |||||
| return trans_node; | return trans_node; | ||||
| } | } | ||||
| AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format, | |||||
| const TypeId &input_type, const TypeId &output_type, | |||||
| const std::vector<size_t> &origin_shape, const TypeId &origin_type) { | |||||
| CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format, | |||||
| const TypeId &input_type, const TypeId &output_type, | |||||
| const std::vector<size_t> &origin_shape, const TypeId &origin_type) { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| std::string input_format = format; | std::string input_format = format; | ||||
| std::string output_format = format; | std::string output_format = format; | ||||
| @@ -94,9 +94,9 @@ void RefreshKernelBuildInfo(const std::string &input_format, const std::string & | |||||
| CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, | CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, | ||||
| const bool need_padding, const std::string &op_name); | const bool need_padding, const std::string &op_name); | ||||
| AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format, | |||||
| const TypeId &input_type, const TypeId &output_type, | |||||
| const std::vector<size_t> &origin_shape, const TypeId &origin_type); | |||||
| CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format, | |||||
| const TypeId &input_type, const TypeId &output_type, | |||||
| const std::vector<size_t> &origin_shape, const TypeId &origin_type); | |||||
| AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | ||||
| const KernelSelectPtr &kernel_select); | const KernelSelectPtr &kernel_select); | ||||
| @@ -26,8 +26,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| namespace { | |||||
| session::KernelWithIndex FindRefOriginNode(const AnfNodePtr &node) { | |||||
| session::KernelWithIndex DealRefTransAndCast::FindRefOriginNode(const AnfNodePtr &node) const { | |||||
| session::KernelWithIndex kernel_with_index = AnfAlgo::VisitKernel(node, 0); | session::KernelWithIndex kernel_with_index = AnfAlgo::VisitKernel(node, 0); | ||||
| AnfNodePtr cur_node = kernel_with_index.first; | AnfNodePtr cur_node = kernel_with_index.first; | ||||
| size_t cur_out_index = kernel_with_index.second; | size_t cur_out_index = kernel_with_index.second; | ||||
| @@ -62,8 +61,8 @@ session::KernelWithIndex FindRefOriginNode(const AnfNodePtr &node) { | |||||
| return kernel_with_index; | return kernel_with_index; | ||||
| } | } | ||||
| void AddRefNodePairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const size_t output_index, | |||||
| const size_t input_index) { | |||||
| void DealRefTransAndCast::AddRefNodePairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||||
| const size_t output_index, const size_t input_index) const { | |||||
| // record the ref_pair | // record the ref_pair | ||||
| auto kernel_graph = func_graph->cast<KernelGraphPtr>(); | auto kernel_graph = func_graph->cast<KernelGraphPtr>(); | ||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | MS_EXCEPTION_IF_NULL(kernel_graph); | ||||
| @@ -72,9 +71,10 @@ void AddRefNodePairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr | |||||
| kernel_graph->AddRefCorrespondPairs(final_pair, kernel_with_index); | kernel_graph->AddRefCorrespondPairs(final_pair, kernel_with_index); | ||||
| } | } | ||||
| void AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr &get_item, | |||||
| const AnfNodePtr &final_node, size_t final_index, | |||||
| const session::KernelWithIndex &origin_pair) { | |||||
| void DealRefTransAndCast::AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||||
| const AnfNodePtr &get_item, const AnfNodePtr &final_node, | |||||
| size_t final_index, | |||||
| const session::KernelWithIndex &origin_pair) const { | |||||
| // record the ref_pair | // record the ref_pair | ||||
| auto kernel_graph = func_graph->cast<KernelGraphPtr>(); | auto kernel_graph = func_graph->cast<KernelGraphPtr>(); | ||||
| MS_EXCEPTION_IF_NULL(kernel_graph); | MS_EXCEPTION_IF_NULL(kernel_graph); | ||||
| @@ -95,9 +95,10 @@ void AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cno | |||||
| // If get_item is nullptr, the additional node is linked to the cnode; | // If get_item is nullptr, the additional node is linked to the cnode; | ||||
| // otherwise it is linked to the get_item node (which itself links to the cnode). | // otherwise it is linked to the get_item node (which itself links to the cnode). | ||||
| AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t output_index, | |||||
| size_t input_index, const AnfNodePtr &get_item) { | |||||
| AnfNodePtr final_node = (get_item == nullptr ? cnode : get_item); | |||||
| CNodePtr DealRefTransAndCast::AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||||
| size_t output_index, size_t input_index, | |||||
| const CNodePtr &get_item) const { | |||||
| CNodePtr final_node = (get_item == nullptr ? cnode : get_item); | |||||
| bool need_refresh_ref_addr = false; | bool need_refresh_ref_addr = false; | ||||
| size_t final_index = output_index; | size_t final_index = output_index; | ||||
| AnfNodePtr input_node = AnfAlgo::GetInputNode(cnode, input_index); | AnfNodePtr input_node = AnfAlgo::GetInputNode(cnode, input_index); | ||||
| @@ -119,6 +120,7 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP | |||||
| auto kernel_select = std::make_shared<KernelSelect>(); | auto kernel_select = std::make_shared<KernelSelect>(); | ||||
| final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name()); | final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name()); | ||||
| RefreshKernelBuildInfo(cur_format, origin_format, final_node, {}, cur_type); | RefreshKernelBuildInfo(cur_format, origin_format, final_node, {}, cur_type); | ||||
| final_node = SplitTransdataIfNotSupported(func_graph, final_node); | |||||
| final_index = 0; | final_index = 0; | ||||
| need_refresh_ref_addr = true; | need_refresh_ref_addr = true; | ||||
| MS_EXCEPTION_IF_NULL(final_node); | MS_EXCEPTION_IF_NULL(final_node); | ||||
| @@ -148,15 +150,15 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP | |||||
| return final_node; | return final_node; | ||||
| } | } | ||||
| AnfNodePtr DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||||
| const std::shared_ptr<kernel::OpInfo> &op_info) { | |||||
| CNodePtr DealRefTransAndCast::DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||||
| const std::shared_ptr<kernel::OpInfo> &op_info) const { | |||||
| MS_EXCEPTION_IF_NULL(op_info); | MS_EXCEPTION_IF_NULL(op_info); | ||||
| auto ref_infos = op_info->ref_infos(); | auto ref_infos = op_info->ref_infos(); | ||||
| std::vector<AnfNodePtr> make_tuple_inputs; | std::vector<AnfNodePtr> make_tuple_inputs; | ||||
| AbstractBasePtrList abstract_list; | AbstractBasePtrList abstract_list; | ||||
| make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); | make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); | ||||
| for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) { | for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) { | ||||
| AnfNodePtr final_node = CreatTupleGetItemNode(func_graph, cnode, output_index); | |||||
| CNodePtr final_node = CreatTupleGetItemNode(func_graph, cnode, output_index); | |||||
| // deal with ref output | // deal with ref output | ||||
| if (ref_infos.count(output_index) != 0) { | if (ref_infos.count(output_index) != 0) { | ||||
| auto input_index = ref_infos.at(output_index); | auto input_index = ref_infos.at(output_index); | ||||
| @@ -167,14 +169,14 @@ AnfNodePtr DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodeP | |||||
| make_tuple_inputs.push_back(final_node); | make_tuple_inputs.push_back(final_node); | ||||
| } | } | ||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs); | |||||
| CNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs); | |||||
| MS_EXCEPTION_IF_NULL(make_tuple); | MS_EXCEPTION_IF_NULL(make_tuple); | ||||
| make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list)); | make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list)); | ||||
| return make_tuple; | return make_tuple; | ||||
| } | } | ||||
| AnfNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||||
| const std::shared_ptr<kernel::OpInfo> &op_info) { | |||||
| CNodePtr DealRefTransAndCast::DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||||
| const std::shared_ptr<kernel::OpInfo> &op_info) const { | |||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| MS_EXCEPTION_IF_NULL(op_info); | MS_EXCEPTION_IF_NULL(op_info); | ||||
| auto ref_infos = op_info->ref_infos(); | auto ref_infos = op_info->ref_infos(); | ||||
| @@ -187,7 +189,6 @@ AnfNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cn | |||||
| } | } | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| } // namespace | |||||
| const BaseRef DealRefTransAndCast::DefinePattern() const { | const BaseRef DealRefTransAndCast::DefinePattern() const { | ||||
| VarPtr V = std::make_shared<CondVar>(UnVisited); | VarPtr V = std::make_shared<CondVar>(UnVisited); | ||||
| @@ -195,7 +196,7 @@ const BaseRef DealRefTransAndCast::DefinePattern() const { | |||||
| return VectorRef({V, Xs}); | return VectorRef({V, Xs}); | ||||
| } | } | ||||
| void DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||||
| void DealRefTransAndCast::DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const { | |||||
| if (AnfAlgo::GetCNodeName(cnode) == kBroadcastOpName) { | if (AnfAlgo::GetCNodeName(cnode) == kBroadcastOpName) { | ||||
| auto input_size = AnfAlgo::GetInputTensorNum(cnode); | auto input_size = AnfAlgo::GetInputTensorNum(cnode); | ||||
| for (size_t i = 0; i < input_size; ++i) { | for (size_t i = 0; i < input_size; ++i) { | ||||
| @@ -238,5 +239,38 @@ const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const A | |||||
| } | } | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| CNodePtr DealRefTransAndCast::SplitTransdataIfNotSupported(const FuncGraphPtr &func_graph, | |||||
| const CNodePtr &cnode) const { | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(cnode); | |||||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||||
| if (kHWSpecialFormatSet.find(kernel_info->GetInputFormat(0)) == kHWSpecialFormatSet.end() || | |||||
| kHWSpecialFormatSet.find(kernel_info->GetOutputFormat(0)) == kHWSpecialFormatSet.end()) { | |||||
| if (IsFormatInvaild(cnode)) { | |||||
| return DoSplit(func_graph, cnode); | |||||
| } | |||||
| return cnode; | |||||
| } | |||||
| auto builder_info_to_default = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info); | |||||
| auto builder_info_to_special_foramt = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info); | |||||
| builder_info_to_default->SetOutputsFormat({kOpFormat_DEFAULT}); | |||||
| builder_info_to_special_foramt->SetInputsFormat({kOpFormat_DEFAULT}); | |||||
| std::vector<AnfNodePtr> next_trans_node_inputs = { | |||||
| NewValueNode(std::make_shared<Primitive>(prim::KPrimTransData->name())), cnode}; | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | |||||
| auto next_trans_node = func_graph->NewCNode(next_trans_node_inputs); | |||||
| next_trans_node->set_abstract(cnode->abstract()); | |||||
| AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_default->Build(), cnode.get()); | |||||
| AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_special_foramt->Build(), next_trans_node.get()); | |||||
| if (IsFormatInvaild(cnode)) { | |||||
| auto after_split_node = DoSplit(func_graph, cnode); | |||||
| AnfAlgo::SetNodeInput(next_trans_node, after_split_node, 0); | |||||
| } | |||||
| if (IsFormatInvaild(next_trans_node)) { | |||||
| return DoSplit(func_graph, next_trans_node); | |||||
| } | |||||
| return next_trans_node; | |||||
| } | |||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -16,20 +16,37 @@ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_ | #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_ | ||||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_ | #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_ | ||||
| #include <memory> | |||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "backend/optimizer/common/optimizer.h" | #include "backend/optimizer/common/optimizer.h" | ||||
| #include "backend/optimizer/ascend/ir_fission/transdata_split.h" | |||||
| #include "backend/optimizer/common/pattern_engine.h" | #include "backend/optimizer/common/pattern_engine.h" | ||||
| #include "backend/optimizer/ascend/ascend_helper.h" | #include "backend/optimizer/ascend/ascend_helper.h" | ||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| class DealRefTransAndCast : public PatternProcessPass { | |||||
| class DealRefTransAndCast : public TransDataSplit { | |||||
| public: | public: | ||||
| explicit DealRefTransAndCast(bool multigraph = true) : PatternProcessPass("deal_ref_trans_and_cast", multigraph) {} | |||||
| explicit DealRefTransAndCast(bool multigraph = true) : TransDataSplit(multigraph, "deal_ref_trans_and_cast") {} | |||||
| ~DealRefTransAndCast() override = default; | ~DealRefTransAndCast() override = default; | ||||
| const BaseRef DefinePattern() const override; | const BaseRef DefinePattern() const override; | ||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | ||||
| private: | |||||
| CNodePtr SplitTransdataIfNotSupported(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const; | |||||
| void DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const; | |||||
| CNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||||
| const std::shared_ptr<kernel::OpInfo> &op_info) const; | |||||
| CNodePtr DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||||
| const std::shared_ptr<kernel::OpInfo> &op_info) const; | |||||
| CNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t output_index, | |||||
| size_t input_index, const CNodePtr &get_item) const; | |||||
| void AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr &get_item, | |||||
| const AnfNodePtr &final_node, size_t final_index, | |||||
| const session::KernelWithIndex &origin_pair) const; | |||||
| void AddRefNodePairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const size_t output_index, | |||||
| const size_t input_index) const; | |||||
| session::KernelWithIndex FindRefOriginNode(const AnfNodePtr &node) const; | |||||
| }; | }; | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -31,15 +31,6 @@ const BaseRef InsertTransOp::DefinePattern() const { | |||||
| return VectorRef({V, Xs}); | return VectorRef({V, Xs}); | ||||
| } | } | ||||
| bool IsGraphOutput(const AnfNodePtr &node, const FuncGraphPtr &func_graph) { | |||||
| auto outputs = AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}); | |||||
| auto iter = std::find(outputs.begin(), outputs.end(), node); | |||||
| if (iter != outputs.end() && GetRealNodeNum(func_graph, node) == 1) { | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | ||||
| const EquivPtr &) const { | const EquivPtr &) const { | ||||
| if (node == nullptr || !AnfAlgo::IsRealKernel(node)) { | if (node == nullptr || !AnfAlgo::IsRealKernel(node)) { | ||||
| @@ -1,66 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h" | |||||
| #include <vector> | |||||
| #include <memory> | |||||
| #include "backend/session/anf_runtime_algorithm.h" | |||||
| #include "utils/trace_base.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| const BaseRef SplitUnsupportedTransData::DefinePattern() const { | |||||
| VarPtr X = std::make_shared<Var>(); | |||||
| return VectorRef({prim::KPrimTransData, X}); | |||||
| } | |||||
| const AnfNodePtr SplitUnsupportedTransData::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||||
| const EquivPtr &) const { | |||||
| if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsRealKernel(node)) { | |||||
| return nullptr; | |||||
| } | |||||
| auto ori_trans_data = node->cast<CNodePtr>(); | |||||
| if (AnfAlgo::GetCNodeName(ori_trans_data) != prim::KPrimTransData->name()) { | |||||
| return nullptr; | |||||
| } | |||||
| auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(ori_trans_data); | |||||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||||
| if (kernel_info->GetInputNum() != 1 || kernel_info->GetOutputNum() != 1) { | |||||
| MS_LOG(EXCEPTION) << "Transdata node's kernel info's input and output format size is not 1" | |||||
| << ori_trans_data->DebugString() << trace::DumpSourceLines(node); | |||||
| } | |||||
| return SplitTransData(func_graph, ori_trans_data); | |||||
| } | |||||
| AnfNodePtr SplitUnsupportedTransData::SplitTransData(const FuncGraphPtr &func_graph, const CNodePtr &trans_node) const { | |||||
| auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(trans_node); | |||||
| if (kHWSpecialFormatSet.find(kernel_info->GetInputFormat(0)) == kHWSpecialFormatSet.end() || | |||||
| kHWSpecialFormatSet.find(kernel_info->GetOutputFormat(0)) == kHWSpecialFormatSet.end()) { | |||||
| return trans_node; | |||||
| } | |||||
| auto builder_info_to_default = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info); | |||||
| auto builder_info_to_special_foramt = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info); | |||||
| builder_info_to_default->SetOutputsFormat({kOpFormat_DEFAULT}); | |||||
| builder_info_to_special_foramt->SetInputsFormat({kOpFormat_DEFAULT}); | |||||
| std::vector<AnfNodePtr> next_trans_node_inputs = { | |||||
| NewValueNode(std::make_shared<Primitive>(prim::KPrimTransData->name())), trans_node}; | |||||
| auto next_trans_node = func_graph->NewCNode(next_trans_node_inputs); | |||||
| next_trans_node->set_abstract(trans_node->abstract()); | |||||
| AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_default->Build(), trans_node.get()); | |||||
| AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_special_foramt->Build(), next_trans_node.get()); | |||||
| return next_trans_node; | |||||
| } | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| @@ -1,37 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H | |||||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H | |||||
| #include "backend/optimizer/common/optimizer.h" | |||||
| namespace mindspore { | |||||
| namespace opt { | |||||
| class SplitUnsupportedTransData : public PatternProcessPass { | |||||
| public: | |||||
| explicit SplitUnsupportedTransData(bool multigraph = true) | |||||
| : PatternProcessPass("split_unsupported_transdata", multigraph) {} | |||||
| ~SplitUnsupportedTransData() override = default; | |||||
| const BaseRef DefinePattern() const override; | |||||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||||
| private: | |||||
| AnfNodePtr SplitTransData(const FuncGraphPtr &func_graph, const CNodePtr &trans_node) const; | |||||
| }; | |||||
| } // namespace opt | |||||
| } // namespace mindspore | |||||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H | |||||
| @@ -27,22 +27,20 @@ const std::set<std::pair<string, string>> invalid_formats_pair = { | |||||
| {kOpFormat_C1HWNCoC0, kOpFormat_DEFAULT}, {kOpFormat_DEFAULT, kOpFormat_FRACTAL_ZN_LSTM}, | {kOpFormat_C1HWNCoC0, kOpFormat_DEFAULT}, {kOpFormat_DEFAULT, kOpFormat_FRACTAL_ZN_LSTM}, | ||||
| {kOpFormat_FRACTAL_ZN_LSTM, kOpFormat_DEFAULT}, {kOpFormat_DEFAULT, kOpFormat_C1HWNCoC0}}; | {kOpFormat_FRACTAL_ZN_LSTM, kOpFormat_DEFAULT}, {kOpFormat_DEFAULT, kOpFormat_C1HWNCoC0}}; | ||||
| bool TransDataSplit::Run(const FuncGraphPtr &func_graph) { | |||||
| const AnfNodePtr TransDataSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||||
| const EquivPtr &) const { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| bool changed = false; | |||||
| std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return()); | |||||
| for (auto &node : node_list) { | |||||
| if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kTransDataOpName) { | |||||
| CheckCNodeInputSize(node->cast<CNodePtr>(), kBackendTransDataInputNum); | |||||
| if (IsFormatInvaild(node)) { | |||||
| TraceGuard guard(std::make_shared<TraceOpt>(node->debug_info())); | |||||
| changed = DoSplit(func_graph, node); | |||||
| } | |||||
| if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kTransDataOpName) { | |||||
| CheckCNodeInputSize(node->cast<CNodePtr>(), kBackendTransDataInputNum); | |||||
| if (IsFormatInvaild(node)) { | |||||
| TraceGuard guard(std::make_shared<TraceOpt>(node->debug_info())); | |||||
| return DoSplit(func_graph, node); | |||||
| } | } | ||||
| } | } | ||||
| return changed; | |||||
| return nullptr; | |||||
| } | } | ||||
| bool TransDataSplit::IsFormatInvaild(const AnfNodePtr &node) { | |||||
| bool TransDataSplit::IsFormatInvaild(const AnfNodePtr &node) const { | |||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| MS_EXCEPTION_IF_NULL(cnode); | MS_EXCEPTION_IF_NULL(cnode); | ||||
| @@ -52,8 +50,14 @@ bool TransDataSplit::IsFormatInvaild(const AnfNodePtr &node) { | |||||
| return invalid_formats_pair.find(format_pair) != invalid_formats_pair.end(); | return invalid_formats_pair.find(format_pair) != invalid_formats_pair.end(); | ||||
| } | } | ||||
| const BaseRef TransDataSplit::DefinePattern() const { | |||||
| VarPtr X = std::make_shared<Var>(); | |||||
| return VectorRef({prim::KPrimTransData, X}); | |||||
| } | |||||
| // TransData cannot convert frac_z to NCHW directly; it must be split into a transdata (frac_z to HWCN) and a transpose (HWCN to NCHW) | // TransData cannot convert frac_z to NCHW directly; it must be split into a transdata (frac_z to HWCN) and a transpose (HWCN to NCHW) | ||||
| bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { | |||||
| CNodePtr TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const { | |||||
| MS_EXCEPTION_IF_NULL(func_graph); | MS_EXCEPTION_IF_NULL(func_graph); | ||||
| MS_EXCEPTION_IF_NULL(node); | MS_EXCEPTION_IF_NULL(node); | ||||
| auto cnode = node->cast<CNodePtr>(); | auto cnode = node->cast<CNodePtr>(); | ||||
| @@ -63,9 +67,9 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n | |||||
| auto input_format = AnfAlgo::GetInputFormat(node, 0); | auto input_format = AnfAlgo::GetInputFormat(node, 0); | ||||
| auto output_format = AnfAlgo::GetOutputFormat(node, 0); | auto output_format = AnfAlgo::GetOutputFormat(node, 0); | ||||
| AnfNodePtr new_transdata_node = nullptr; | |||||
| AnfNodePtr new_transpose_node = nullptr; | |||||
| AnfNodePtr new_replace_node = nullptr; | |||||
| CNodePtr new_transdata_node = nullptr; | |||||
| CNodePtr new_transpose_node = nullptr; | |||||
| CNodePtr new_replace_node = nullptr; | |||||
| auto padding_axis = AnfAlgo::GetOutputReshapeType(node, 0); | auto padding_axis = AnfAlgo::GetOutputReshapeType(node, 0); | ||||
| // If output_format is default (or NCHW), the transdata is split into transdata->transpose; otherwise into transpose->transdata | // If output_format is default (or NCHW), the transdata is split into transdata->transpose; otherwise into transpose->transdata | ||||
| if (output_format == kOpFormat_DEFAULT || output_format == kOpFormat_NCHW) { | if (output_format == kOpFormat_DEFAULT || output_format == kOpFormat_NCHW) { | ||||
| @@ -96,16 +100,8 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n | |||||
| new_transdata_node->set_abstract(node->abstract()); | new_transdata_node->set_abstract(node->abstract()); | ||||
| new_replace_node = new_transdata_node; | new_replace_node = new_transdata_node; | ||||
| } | } | ||||
| FuncGraphManagerPtr manager = func_graph->manager(); | |||||
| MS_EXCEPTION_IF_NULL(manager); | |||||
| manager->AddFuncGraph(func_graph); | |||||
| if (!manager->Replace(node, new_replace_node)) { | |||||
| MS_LOG(EXCEPTION) << "Manager replace node failed" | |||||
| << " trace: " << trace::DumpSourceLines(node); | |||||
| } | |||||
| MS_LOG(INFO) << "Transdata node:" << cnode->DebugString() << "split success."; | MS_LOG(INFO) << "Transdata node:" << cnode->DebugString() << "split success."; | ||||
| return true; | |||||
| return new_replace_node; | |||||
| } | } | ||||
| } // namespace opt | } // namespace opt | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -29,15 +29,17 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| namespace opt { | namespace opt { | ||||
| class TransDataSplit : public Pass { | |||||
| class TransDataSplit : public PatternProcessPass { | |||||
| public: | public: | ||||
| TransDataSplit() : Pass("trans_data_split"), kernel_select_(std::make_shared<KernelSelect>()) {} | |||||
| explicit TransDataSplit(bool multigraph = true, const string &name = "trans_data_split") | |||||
| : PatternProcessPass(name, multigraph), kernel_select_(std::make_shared<KernelSelect>()) {} | |||||
| ~TransDataSplit() override = default; | ~TransDataSplit() override = default; | ||||
| bool Run(const FuncGraphPtr &graph) override; | |||||
| const BaseRef DefinePattern() const override; | |||||
| const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override; | |||||
| private: | |||||
| bool DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node); | |||||
| bool IsFormatInvaild(const AnfNodePtr &node); | |||||
| protected: | |||||
| CNodePtr DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const; | |||||
| bool IsFormatInvaild(const AnfNodePtr &node) const; | |||||
| KernelSelectPtr kernel_select_; | KernelSelectPtr kernel_select_; | ||||
| }; | }; | ||||
| } // namespace opt | } // namespace opt | ||||
| @@ -481,13 +481,13 @@ bool IsNotRealUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx) { | |||||
| CNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx) { | |||||
| auto idx = NewValueNode(SizeToLong(output_idx)); | auto idx = NewValueNode(SizeToLong(output_idx)); | ||||
| MS_EXCEPTION_IF_NULL(idx); | MS_EXCEPTION_IF_NULL(idx); | ||||
| auto imm = std::make_shared<Int64Imm>(SizeToLong(output_idx)); | auto imm = std::make_shared<Int64Imm>(SizeToLong(output_idx)); | ||||
| auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm); | auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm); | ||||
| idx->set_abstract(abstract_scalar); | idx->set_abstract(abstract_scalar); | ||||
| AnfNodePtr tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx}); | |||||
| CNodePtr tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx}); | |||||
| MS_EXCEPTION_IF_NULL(tuple_getitem); | MS_EXCEPTION_IF_NULL(tuple_getitem); | ||||
| tuple_getitem->set_scope(node->scope()); | tuple_getitem->set_scope(node->scope()); | ||||
| std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx); | std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx); | ||||
| @@ -169,7 +169,7 @@ void HideNopNode(session::KernelGraph *const graph); | |||||
| void RemoveNopNode(session::KernelGraph *const graph); | void RemoveNopNode(session::KernelGraph *const graph); | ||||
| AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx); | |||||
| CNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx); | |||||
| bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node); | bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node); | ||||