From: @lianliguang Reviewed-by: @zhoufeng54,@chujinjin Signed-off-by: @chujinjintags/v1.1.0
| @@ -71,7 +71,6 @@ | |||
| #include "backend/optimizer/ascend/format_type/insert_transpose_for_dyanmic_gru_v2.h" | |||
| #include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h" | |||
| #include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h" | |||
| #include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h" | |||
| #include "backend/optimizer/ascend/format_type/convert_cast_format.h" | |||
| #include "backend/optimizer/pass/getitem_tuple.h" | |||
| #include "backend/optimizer/pass/optimize_dependence.h" | |||
| @@ -250,7 +249,6 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap | |||
| mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>()); | |||
| mixed_precision_pm->AddPass(std::make_shared<LayerNormBetaGammaBackpropFusion>()); | |||
| mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>()); | |||
| mixed_precision_pm->AddPass(std::make_shared<SplitUnsupportedTransData>()); | |||
| mixed_precision_pm->AddPass(std::make_shared<ConvertUnSupportNodeToAICPU>()); | |||
| mixed_precision_pm->AddPass(std::make_shared<RemoveInternalOutputCast>()); | |||
| optimizer->AddPassManager(mixed_precision_pm); | |||
| @@ -256,9 +256,9 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, | |||
| return trans_node; | |||
| } | |||
| AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format, | |||
| const TypeId &input_type, const TypeId &output_type, | |||
| const std::vector<size_t> &origin_shape, const TypeId &origin_type) { | |||
| CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format, | |||
| const TypeId &input_type, const TypeId &output_type, | |||
| const std::vector<size_t> &origin_shape, const TypeId &origin_type) { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| std::string input_format = format; | |||
| std::string output_format = format; | |||
| @@ -94,9 +94,9 @@ void RefreshKernelBuildInfo(const std::string &input_format, const std::string & | |||
| CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, | |||
| const bool need_padding, const std::string &op_name); | |||
| AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format, | |||
| const TypeId &input_type, const TypeId &output_type, | |||
| const std::vector<size_t> &origin_shape, const TypeId &origin_type); | |||
| CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format, | |||
| const TypeId &input_type, const TypeId &output_type, | |||
| const std::vector<size_t> &origin_shape, const TypeId &origin_type); | |||
| AnfNodePtr InsertTransOpForInput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const KernelSelectPtr &kernel_select); | |||
| @@ -26,8 +26,7 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace { | |||
| session::KernelWithIndex FindRefOriginNode(const AnfNodePtr &node) { | |||
| session::KernelWithIndex DealRefTransAndCast::FindRefOriginNode(const AnfNodePtr &node) const { | |||
| session::KernelWithIndex kernel_with_index = AnfAlgo::VisitKernel(node, 0); | |||
| AnfNodePtr cur_node = kernel_with_index.first; | |||
| size_t cur_out_index = kernel_with_index.second; | |||
| @@ -62,8 +61,8 @@ session::KernelWithIndex FindRefOriginNode(const AnfNodePtr &node) { | |||
| return kernel_with_index; | |||
| } | |||
| void AddRefNodePairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const size_t output_index, | |||
| const size_t input_index) { | |||
| void DealRefTransAndCast::AddRefNodePairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| const size_t output_index, const size_t input_index) const { | |||
| // record the ref_pair | |||
| auto kernel_graph = func_graph->cast<KernelGraphPtr>(); | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| @@ -72,9 +71,10 @@ void AddRefNodePairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr | |||
| kernel_graph->AddRefCorrespondPairs(final_pair, kernel_with_index); | |||
| } | |||
| void AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr &get_item, | |||
| const AnfNodePtr &final_node, size_t final_index, | |||
| const session::KernelWithIndex &origin_pair) { | |||
| void DealRefTransAndCast::AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| const AnfNodePtr &get_item, const AnfNodePtr &final_node, | |||
| size_t final_index, | |||
| const session::KernelWithIndex &origin_pair) const { | |||
| // record the ref_pair | |||
| auto kernel_graph = func_graph->cast<KernelGraphPtr>(); | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| @@ -95,9 +95,10 @@ void AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cno | |||
| // if get_item is nullptr, the additional node will link to the cnode | |||
| // else the additional node will link to the get_item node (the get_item node link to cnode) | |||
| AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t output_index, | |||
| size_t input_index, const AnfNodePtr &get_item) { | |||
| AnfNodePtr final_node = (get_item == nullptr ? cnode : get_item); | |||
| CNodePtr DealRefTransAndCast::AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| size_t output_index, size_t input_index, | |||
| const CNodePtr &get_item) const { | |||
| CNodePtr final_node = (get_item == nullptr ? cnode : get_item); | |||
| bool need_refresh_ref_addr = false; | |||
| size_t final_index = output_index; | |||
| AnfNodePtr input_node = AnfAlgo::GetInputNode(cnode, input_index); | |||
| @@ -119,6 +120,7 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP | |||
| auto kernel_select = std::make_shared<KernelSelect>(); | |||
| final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name()); | |||
| RefreshKernelBuildInfo(cur_format, origin_format, final_node, {}, cur_type); | |||
| final_node = SplitTransdataIfNotSupported(func_graph, final_node); | |||
| final_index = 0; | |||
| need_refresh_ref_addr = true; | |||
| MS_EXCEPTION_IF_NULL(final_node); | |||
| @@ -148,15 +150,15 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP | |||
| return final_node; | |||
| } | |||
| AnfNodePtr DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| const std::shared_ptr<kernel::OpInfo> &op_info) { | |||
| CNodePtr DealRefTransAndCast::DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| const std::shared_ptr<kernel::OpInfo> &op_info) const { | |||
| MS_EXCEPTION_IF_NULL(op_info); | |||
| auto ref_infos = op_info->ref_infos(); | |||
| std::vector<AnfNodePtr> make_tuple_inputs; | |||
| AbstractBasePtrList abstract_list; | |||
| make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); | |||
| for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(cnode); ++output_index) { | |||
| AnfNodePtr final_node = CreatTupleGetItemNode(func_graph, cnode, output_index); | |||
| CNodePtr final_node = CreatTupleGetItemNode(func_graph, cnode, output_index); | |||
| // deal with ref output | |||
| if (ref_infos.count(output_index) != 0) { | |||
| auto input_index = ref_infos.at(output_index); | |||
| @@ -167,14 +169,14 @@ AnfNodePtr DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodeP | |||
| make_tuple_inputs.push_back(final_node); | |||
| } | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs); | |||
| CNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs); | |||
| MS_EXCEPTION_IF_NULL(make_tuple); | |||
| make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list)); | |||
| return make_tuple; | |||
| } | |||
| AnfNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| const std::shared_ptr<kernel::OpInfo> &op_info) { | |||
| CNodePtr DealRefTransAndCast::DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| const std::shared_ptr<kernel::OpInfo> &op_info) const { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| MS_EXCEPTION_IF_NULL(op_info); | |||
| auto ref_infos = op_info->ref_infos(); | |||
| @@ -187,7 +189,6 @@ AnfNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cn | |||
| } | |||
| return nullptr; | |||
| } | |||
| } // namespace | |||
| const BaseRef DealRefTransAndCast::DefinePattern() const { | |||
| VarPtr V = std::make_shared<CondVar>(UnVisited); | |||
| @@ -195,7 +196,7 @@ const BaseRef DealRefTransAndCast::DefinePattern() const { | |||
| return VectorRef({V, Xs}); | |||
| } | |||
| void DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { | |||
| void DealRefTransAndCast::DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const { | |||
| if (AnfAlgo::GetCNodeName(cnode) == kBroadcastOpName) { | |||
| auto input_size = AnfAlgo::GetInputTensorNum(cnode); | |||
| for (size_t i = 0; i < input_size; ++i) { | |||
| @@ -238,5 +239,38 @@ const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const A | |||
| } | |||
| return nullptr; | |||
| } | |||
| CNodePtr DealRefTransAndCast::SplitTransdataIfNotSupported(const FuncGraphPtr &func_graph, | |||
| const CNodePtr &cnode) const { | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(cnode); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| if (kHWSpecialFormatSet.find(kernel_info->GetInputFormat(0)) == kHWSpecialFormatSet.end() || | |||
| kHWSpecialFormatSet.find(kernel_info->GetOutputFormat(0)) == kHWSpecialFormatSet.end()) { | |||
| if (IsFormatInvaild(cnode)) { | |||
| return DoSplit(func_graph, cnode); | |||
| } | |||
| return cnode; | |||
| } | |||
| auto builder_info_to_default = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info); | |||
| auto builder_info_to_special_foramt = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info); | |||
| builder_info_to_default->SetOutputsFormat({kOpFormat_DEFAULT}); | |||
| builder_info_to_special_foramt->SetInputsFormat({kOpFormat_DEFAULT}); | |||
| std::vector<AnfNodePtr> next_trans_node_inputs = { | |||
| NewValueNode(std::make_shared<Primitive>(prim::KPrimTransData->name())), cnode}; | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| auto next_trans_node = func_graph->NewCNode(next_trans_node_inputs); | |||
| next_trans_node->set_abstract(cnode->abstract()); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_default->Build(), cnode.get()); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_special_foramt->Build(), next_trans_node.get()); | |||
| if (IsFormatInvaild(cnode)) { | |||
| auto after_split_node = DoSplit(func_graph, cnode); | |||
| AnfAlgo::SetNodeInput(next_trans_node, after_split_node, 0); | |||
| } | |||
| if (IsFormatInvaild(next_trans_node)) { | |||
| return DoSplit(func_graph, next_trans_node); | |||
| } | |||
| return next_trans_node; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -16,20 +16,37 @@ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_DEAL_REF_TRANS_AND_CAST_H_ | |||
| #include <memory> | |||
| #include "ir/anf.h" | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| #include "backend/optimizer/ascend/ir_fission/transdata_split.h" | |||
| #include "backend/optimizer/common/pattern_engine.h" | |||
| #include "backend/optimizer/ascend/ascend_helper.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class DealRefTransAndCast : public PatternProcessPass { | |||
| class DealRefTransAndCast : public TransDataSplit { | |||
| public: | |||
| explicit DealRefTransAndCast(bool multigraph = true) : PatternProcessPass("deal_ref_trans_and_cast", multigraph) {} | |||
| explicit DealRefTransAndCast(bool multigraph = true) : TransDataSplit(multigraph, "deal_ref_trans_and_cast") {} | |||
| ~DealRefTransAndCast() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| private: | |||
| CNodePtr SplitTransdataIfNotSupported(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const; | |||
| void DealBroadCastAsRef(const FuncGraphPtr &func_graph, const CNodePtr &cnode) const; | |||
| CNodePtr DealRefSigleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| const std::shared_ptr<kernel::OpInfo> &op_info) const; | |||
| CNodePtr DealRefForMultipleOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, | |||
| const std::shared_ptr<kernel::OpInfo> &op_info) const; | |||
| CNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, size_t output_index, | |||
| size_t input_index, const CNodePtr &get_item) const; | |||
| void AddRefPairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const AnfNodePtr &get_item, | |||
| const AnfNodePtr &final_node, size_t final_index, | |||
| const session::KernelWithIndex &origin_pair) const; | |||
| void AddRefNodePairToKernelGraph(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const size_t output_index, | |||
| const size_t input_index) const; | |||
| session::KernelWithIndex FindRefOriginNode(const AnfNodePtr &node) const; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -31,15 +31,6 @@ const BaseRef InsertTransOp::DefinePattern() const { | |||
| return VectorRef({V, Xs}); | |||
| } | |||
| bool IsGraphOutput(const AnfNodePtr &node, const FuncGraphPtr &func_graph) { | |||
| auto outputs = AnfAlgo::GetAllOutput(func_graph->output(), {prim::kPrimTupleGetItem}); | |||
| auto iter = std::find(outputs.begin(), outputs.end(), node); | |||
| if (iter != outputs.end() && GetRealNodeNum(func_graph, node) == 1) { | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| if (node == nullptr || !AnfAlgo::IsRealKernel(node)) { | |||
| @@ -1,66 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h" | |||
| #include <vector> | |||
| #include <memory> | |||
| #include "backend/session/anf_runtime_algorithm.h" | |||
| #include "utils/trace_base.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| const BaseRef SplitUnsupportedTransData::DefinePattern() const { | |||
| VarPtr X = std::make_shared<Var>(); | |||
| return VectorRef({prim::KPrimTransData, X}); | |||
| } | |||
| const AnfNodePtr SplitUnsupportedTransData::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsRealKernel(node)) { | |||
| return nullptr; | |||
| } | |||
| auto ori_trans_data = node->cast<CNodePtr>(); | |||
| if (AnfAlgo::GetCNodeName(ori_trans_data) != prim::KPrimTransData->name()) { | |||
| return nullptr; | |||
| } | |||
| auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(ori_trans_data); | |||
| MS_EXCEPTION_IF_NULL(kernel_info); | |||
| if (kernel_info->GetInputNum() != 1 || kernel_info->GetOutputNum() != 1) { | |||
| MS_LOG(EXCEPTION) << "Transdata node's kernel info's input and output format size is not 1" | |||
| << ori_trans_data->DebugString() << trace::DumpSourceLines(node); | |||
| } | |||
| return SplitTransData(func_graph, ori_trans_data); | |||
| } | |||
| AnfNodePtr SplitUnsupportedTransData::SplitTransData(const FuncGraphPtr &func_graph, const CNodePtr &trans_node) const { | |||
| auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(trans_node); | |||
| if (kHWSpecialFormatSet.find(kernel_info->GetInputFormat(0)) == kHWSpecialFormatSet.end() || | |||
| kHWSpecialFormatSet.find(kernel_info->GetOutputFormat(0)) == kHWSpecialFormatSet.end()) { | |||
| return trans_node; | |||
| } | |||
| auto builder_info_to_default = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info); | |||
| auto builder_info_to_special_foramt = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info); | |||
| builder_info_to_default->SetOutputsFormat({kOpFormat_DEFAULT}); | |||
| builder_info_to_special_foramt->SetInputsFormat({kOpFormat_DEFAULT}); | |||
| std::vector<AnfNodePtr> next_trans_node_inputs = { | |||
| NewValueNode(std::make_shared<Primitive>(prim::KPrimTransData->name())), trans_node}; | |||
| auto next_trans_node = func_graph->NewCNode(next_trans_node_inputs); | |||
| next_trans_node->set_abstract(trans_node->abstract()); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_default->Build(), trans_node.get()); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_special_foramt->Build(), next_trans_node.get()); | |||
| return next_trans_node; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -1,37 +0,0 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H | |||
| #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H | |||
| #include "backend/optimizer/common/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class SplitUnsupportedTransData : public PatternProcessPass { | |||
| public: | |||
| explicit SplitUnsupportedTransData(bool multigraph = true) | |||
| : PatternProcessPass("split_unsupported_transdata", multigraph) {} | |||
| ~SplitUnsupportedTransData() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| private: | |||
| AnfNodePtr SplitTransData(const FuncGraphPtr &func_graph, const CNodePtr &trans_node) const; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H | |||
| @@ -27,22 +27,20 @@ const std::set<std::pair<string, string>> invalid_formats_pair = { | |||
| {kOpFormat_C1HWNCoC0, kOpFormat_DEFAULT}, {kOpFormat_DEFAULT, kOpFormat_FRACTAL_ZN_LSTM}, | |||
| {kOpFormat_FRACTAL_ZN_LSTM, kOpFormat_DEFAULT}, {kOpFormat_DEFAULT, kOpFormat_C1HWNCoC0}}; | |||
| bool TransDataSplit::Run(const FuncGraphPtr &func_graph) { | |||
| const AnfNodePtr TransDataSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| bool changed = false; | |||
| std::vector<AnfNodePtr> node_list = TopoSort(func_graph->get_return()); | |||
| for (auto &node : node_list) { | |||
| if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kTransDataOpName) { | |||
| CheckCNodeInputSize(node->cast<CNodePtr>(), kBackendTransDataInputNum); | |||
| if (IsFormatInvaild(node)) { | |||
| TraceGuard guard(std::make_shared<TraceOpt>(node->debug_info())); | |||
| changed = DoSplit(func_graph, node); | |||
| } | |||
| if (node != nullptr && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == kTransDataOpName) { | |||
| CheckCNodeInputSize(node->cast<CNodePtr>(), kBackendTransDataInputNum); | |||
| if (IsFormatInvaild(node)) { | |||
| TraceGuard guard(std::make_shared<TraceOpt>(node->debug_info())); | |||
| return DoSplit(func_graph, node); | |||
| } | |||
| } | |||
| return changed; | |||
| return nullptr; | |||
| } | |||
| bool TransDataSplit::IsFormatInvaild(const AnfNodePtr &node) { | |||
| bool TransDataSplit::IsFormatInvaild(const AnfNodePtr &node) const { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| @@ -52,8 +50,14 @@ bool TransDataSplit::IsFormatInvaild(const AnfNodePtr &node) { | |||
| return invalid_formats_pair.find(format_pair) != invalid_formats_pair.end(); | |||
| } | |||
| const BaseRef TransDataSplit::DefinePattern() const { | |||
| VarPtr X = std::make_shared<Var>(); | |||
| return VectorRef({prim::KPrimTransData, X}); | |||
| } | |||
| // transdata cannot support frac_z to nchw need split transdata(frac_z-HWCN) and transpose(HWCN-NCHW) | |||
| bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { | |||
| CNodePtr TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const { | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| @@ -63,9 +67,9 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n | |||
| auto input_format = AnfAlgo::GetInputFormat(node, 0); | |||
| auto output_format = AnfAlgo::GetOutputFormat(node, 0); | |||
| AnfNodePtr new_transdata_node = nullptr; | |||
| AnfNodePtr new_transpose_node = nullptr; | |||
| AnfNodePtr new_replace_node = nullptr; | |||
| CNodePtr new_transdata_node = nullptr; | |||
| CNodePtr new_transpose_node = nullptr; | |||
| CNodePtr new_replace_node = nullptr; | |||
| auto padding_axis = AnfAlgo::GetOutputReshapeType(node, 0); | |||
| // if output_format=default transdata need split transdata->transpose else transpose->transdata | |||
| if (output_format == kOpFormat_DEFAULT || output_format == kOpFormat_NCHW) { | |||
| @@ -96,16 +100,8 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n | |||
| new_transdata_node->set_abstract(node->abstract()); | |||
| new_replace_node = new_transdata_node; | |||
| } | |||
| FuncGraphManagerPtr manager = func_graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| manager->AddFuncGraph(func_graph); | |||
| if (!manager->Replace(node, new_replace_node)) { | |||
| MS_LOG(EXCEPTION) << "Manager replace node failed" | |||
| << " trace: " << trace::DumpSourceLines(node); | |||
| } | |||
| MS_LOG(INFO) << "Transdata node:" << cnode->DebugString() << "split success."; | |||
| return true; | |||
| return new_replace_node; | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -29,15 +29,17 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class TransDataSplit : public Pass { | |||
| class TransDataSplit : public PatternProcessPass { | |||
| public: | |||
| TransDataSplit() : Pass("trans_data_split"), kernel_select_(std::make_shared<KernelSelect>()) {} | |||
| explicit TransDataSplit(bool multigraph = true, const string &name = "trans_data_split") | |||
| : PatternProcessPass(name, multigraph), kernel_select_(std::make_shared<KernelSelect>()) {} | |||
| ~TransDataSplit() override = default; | |||
| bool Run(const FuncGraphPtr &graph) override; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override; | |||
| private: | |||
| bool DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node); | |||
| bool IsFormatInvaild(const AnfNodePtr &node); | |||
| protected: | |||
| CNodePtr DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &node) const; | |||
| bool IsFormatInvaild(const AnfNodePtr &node) const; | |||
| KernelSelectPtr kernel_select_; | |||
| }; | |||
| } // namespace opt | |||
| @@ -481,13 +481,13 @@ bool IsNotRealUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node) { | |||
| return true; | |||
| } | |||
| AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx) { | |||
| CNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx) { | |||
| auto idx = NewValueNode(SizeToLong(output_idx)); | |||
| MS_EXCEPTION_IF_NULL(idx); | |||
| auto imm = std::make_shared<Int64Imm>(SizeToLong(output_idx)); | |||
| auto abstract_scalar = std::make_shared<abstract::AbstractScalar>(imm); | |||
| idx->set_abstract(abstract_scalar); | |||
| AnfNodePtr tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx}); | |||
| CNodePtr tuple_getitem = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, idx}); | |||
| MS_EXCEPTION_IF_NULL(tuple_getitem); | |||
| tuple_getitem->set_scope(node->scope()); | |||
| std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx); | |||
| @@ -169,7 +169,7 @@ void HideNopNode(session::KernelGraph *const graph); | |||
| void RemoveNopNode(session::KernelGraph *const graph); | |||
| AnfNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx); | |||
| CNodePtr CreatTupleGetItemNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, size_t output_idx); | |||
| bool IsUsedByOthers(const FuncGraphPtr &graph, const AnfNodePtr &node); | |||