From 9c005b182c7449988473a735588007c3d3874fda Mon Sep 17 00:00:00 2001 From: WilliamLian Date: Sat, 6 Jun 2020 15:25:17 +0800 Subject: [PATCH] convert the first format of the kernel info of DropoutDoMask nodes that are connected to the same DropoutGenMask --- mindspore/ccsrc/kernel/kernel.h | 2 +- mindspore/ccsrc/kernel/kernel_build_info.cc | 15 ++ mindspore/ccsrc/kernel/kernel_build_info.h | 4 + mindspore/ccsrc/kernel/kernel_query.cc | 10 +- mindspore/ccsrc/operator/ops.cc | 1 + mindspore/ccsrc/operator/ops.h | 1 + .../ascend/ascend_backend_optimization.cc | 24 +-- .../ascend/ascend_backend_optimization.h | 1 - .../pre_activate/ascend/ascend_helper.cc | 2 +- .../format_type/insert_cast_for_runop.cc | 48 ------ .../format_type/insert_cast_for_runop.h | 35 ---- .../rectify_do_mask_kernel_info.cc | 154 ++++++++++++++++++ .../format_type/rectify_do_mask_kernel_info.h | 41 +++++ 13 files changed, 230 insertions(+), 108 deletions(-) delete mode 100644 mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast_for_runop.cc delete mode 100644 mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast_for_runop.h create mode 100644 mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc create mode 100644 mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h diff --git a/mindspore/ccsrc/kernel/kernel.h b/mindspore/ccsrc/kernel/kernel.h index 271f6f20fa..211e81c684 100644 --- a/mindspore/ccsrc/kernel/kernel.h +++ b/mindspore/ccsrc/kernel/kernel.h @@ -31,7 +31,7 @@ enum KernelType : int { UNKNOWN_KERNEL_TYPE = 0, AUTO_DIFF_KERNEL, AICPU_KERNEL, namespace kernel { -enum Axis { +enum Axis : int { N = 0, C, H, diff --git a/mindspore/ccsrc/kernel/kernel_build_info.cc b/mindspore/ccsrc/kernel/kernel_build_info.cc index ce7164a0d1..4ad75dc8a4 100644 --- a/mindspore/ccsrc/kernel/kernel_build_info.cc +++ b/mindspore/ccsrc/kernel/kernel_build_info.cc @@ -167,5 +167,20 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetOpPattern(OpPattern pattern) {
MS_EXCEPTION_IF_NULL(kernel_build_info_); kernel_build_info_->op_pattern_ = pattern; } +void KernelBuildInfo::KernelBuildInfoBuilder::SetInputFormat(const std::string &format, size_t index) { + MS_EXCEPTION_IF_NULL(kernel_build_info_); + if (index >= kernel_build_info_->inputs_format_.size()) { + MS_LOG(EXCEPTION) << "index out of range!"; + } + kernel_build_info_->inputs_format_[index] = format; +} + +void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputFormat(const std::string &format, size_t index) { + MS_EXCEPTION_IF_NULL(kernel_build_info_); + if (index >= kernel_build_info_->outputs_format_.size()) { + MS_LOG(EXCEPTION) << "index out of range!"; + } + kernel_build_info_->outputs_format_[index] = format; +} } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/kernel_build_info.h b/mindspore/ccsrc/kernel/kernel_build_info.h index d17b41a6fc..9207b0b863 100644 --- a/mindspore/ccsrc/kernel/kernel_build_info.h +++ b/mindspore/ccsrc/kernel/kernel_build_info.h @@ -131,6 +131,10 @@ class KernelBuildInfo::KernelBuildInfoBuilder { void SetOpPattern(OpPattern pattern); + void SetInputFormat(const std::string &format, size_t index); + + void SetOutputFormat(const std::string &format, size_t index); + std::shared_ptr Build(); private: diff --git a/mindspore/ccsrc/kernel/kernel_query.cc b/mindspore/ccsrc/kernel/kernel_query.cc index 8d3ee64591..d2599fb881 100755 --- a/mindspore/ccsrc/kernel/kernel_query.cc +++ b/mindspore/ccsrc/kernel/kernel_query.cc @@ -41,8 +41,16 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node, } else { MS_LOG(WARNING) << "All kernel Info list does not match any kernel info "; for (size_t index = 0; index < kernel_info_list->size(); ++index) { + std::ostringstream buffer; MS_EXCEPTION_IF_NULL(kernel_info_list->at(index)); - MS_LOG(WARNING) << "kernel [ " << index << " ] :" << kernel_info_list->at(index)->ToString(); + if (AnfAlgo::GetOutputTensorNum(kernel_node) != kernel_info_list->at(index)->GetOutputNum()) {
+ buffer << "Kernel node's output size [" << AnfAlgo::GetOutputTensorNum(kernel_node) << "]" + << " cannot match the kernel's output size [" << kernel_info_list->at(index)->GetOutputNum() << "]"; + } else { + buffer << "Kernel node's input size [" << AnfAlgo::GetInputTensorNum(kernel_node) << "]" + << " cannot match the kernel's input size [" << kernel_info_list->at(index)->GetInputNum() << "]"; + } + MS_LOG(WARNING) << "kernel [ " << index << " ] :" << kernel_info_list->at(index)->ToString() << buffer.str(); } kernel_info_list->clear(); MS_LOG(WARNING) << "node" << kernel_node->DebugString() << "'s output size : [" diff --git a/mindspore/ccsrc/operator/ops.cc b/mindspore/ccsrc/operator/ops.cc index ec9c6a58ee..1bb5548188 100755 --- a/mindspore/ccsrc/operator/ops.cc +++ b/mindspore/ccsrc/operator/ops.cc @@ -205,6 +205,7 @@ const PrimitivePtr kPrimLayerNormGrad = std::make_shared("LayerNormGr const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared("LayerNormXBackprop"); const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared("LayerNormBetaGammaBackprop"); const PrimitivePtr kPrimDropoutGenMask = std::make_shared("DropoutGenMask"); +const PrimitivePtr kPrimDropoutDoMask = std::make_shared("DropoutDoMask"); const PrimitivePtr kPrimOneHot = std::make_shared("OneHot"); const PrimitivePtr kPrimGelu = std::make_shared("Gelu"); const PrimitivePtr kPrimGeluGrad = std::make_shared("GeluGrad"); diff --git a/mindspore/ccsrc/operator/ops.h b/mindspore/ccsrc/operator/ops.h index fac88f3a50..8e849cc89e 100755 --- a/mindspore/ccsrc/operator/ops.h +++ b/mindspore/ccsrc/operator/ops.h @@ -211,6 +211,7 @@ extern const PrimitivePtr kPrimLayerNormGrad; extern const PrimitivePtr kPrimLayerNormXBackprop; extern const PrimitivePtr kPrimLayerNormBetaGammaBackprop; extern const PrimitivePtr kPrimDropoutGenMask; +extern const PrimitivePtr kPrimDropoutDoMask; extern const PrimitivePtr kPrimOneHot; extern const PrimitivePtr kPrimGelu; extern const PrimitivePtr
kPrimGeluGrad; diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc index 10e5e12db5..814f69e6db 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc @@ -54,6 +54,7 @@ #include "pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.h" #include "pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion.h" #include "pre_activate/ascend/format_type/insert_trans_op.h" +#include "pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h" #include "pre_activate/pass/getitem_tuple.h" #include "pre_activate/pass/optimize_dependence.h" #include "pre_activate/pass/erase_visit_attr.h" @@ -79,7 +80,6 @@ #include "pre_activate/ascend/format_type/deal_ref_trans_and_cast.h" #include "pre_activate/ascend/enhancer/add_memcpy_async.h" #include "pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.h" -#include "pre_activate/ascend/format_type/insert_cast_for_runop.h" #include "pre_activate/ascend/format_type/insert_transdata_for_runop.h" #include "pre_activate/ascend/enhancer/getnext_memcpy_elimination.h" #include "pre_activate/ascend/ir_fission/addn_fission.h" @@ -145,6 +145,7 @@ void RunOpAscendDataLayout(const std::shared_ptr &kernel_g MS_EXCEPTION_IF_NULL(kernel_graph); auto optimizer = std::make_shared(); auto data_layout_pm = std::make_shared("pynative_transop_pm"); + data_layout_pm->AddPass(std::make_shared()); data_layout_pm->AddPass(std::make_shared()); data_layout_pm->AddPass(std::make_shared()); data_layout_pm->AddPass(std::make_shared()); @@ -157,30 +158,11 @@ void RunOpAscendDataLayout(const std::shared_ptr &kernel_g kernel_graph->SetExecOrderByDefault(); } -void RunOpAscendMixPrecision(const std::shared_ptr &kernel_graph) { - MS_EXCEPTION_IF_NULL(kernel_graph); - auto optimizer = std::make_shared(); - auto mixed_precision_pm = std::make_shared("pynative_transop_pm"); - 
mixed_precision_pm->AddPass(std::make_shared()); - mixed_precision_pm->AddPass(std::make_shared()); - mixed_precision_pm->AddPass(std::make_shared()); - mixed_precision_pm->AddPass(std::make_shared()); - mixed_precision_pm->AddPass(std::make_shared()); - mixed_precision_pm->AddPass(std::make_shared()); - mixed_precision_pm->AddPass(std::make_shared()); - mixed_precision_pm->AddPass(std::make_shared()); - mixed_precision_pm->AddPass(std::make_shared()); - mixed_precision_pm->AddPass(std::make_shared()); - mixed_precision_pm->AddPass(std::make_shared()); - optimizer->AddPassManager(mixed_precision_pm); - (void)optimizer->Optimize(kernel_graph); - kernel_graph->SetExecOrderByDefault(); -} - void AscendDataLayout(const std::shared_ptr &kernel_graph) { MS_EXCEPTION_IF_NULL(kernel_graph); auto optimizer = std::make_shared(); auto data_layout_pm = std::make_shared("transop_pm"); + data_layout_pm->AddPass(std::make_shared()); data_layout_pm->AddPass(std::make_shared()); data_layout_pm->AddPass(std::make_shared()); data_layout_pm->AddPass(std::make_shared()); diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.h b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.h index 914b4c053a..46d9f9bd1b 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.h +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.h @@ -20,7 +20,6 @@ namespace mindspore { namespace opt { void RunOpAscendDataLayout(const std::shared_ptr &kernel_graph); -void RunOpAscendMixPrecision(const std::shared_ptr &kernel_graph); void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr &kernel_graph); void AscendDataLayout(const std::shared_ptr &kernel_graph); void AscendMixPrecision(const std::shared_ptr &kernel_graph); diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc index 1203f4d406..982538c417 100644 --- 
a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc @@ -65,7 +65,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt dtype = AnfAlgo::GetInputDeviceDataType(cnode, insert_index); dst_format = AnfAlgo::GetInputFormat(cnode, insert_index); input_node = AnfAlgo::GetInputNode(cnode, insert_index); - padding_axis = AnfAlgo::GetInputReshapeType(node, 0); + padding_axis = AnfAlgo::GetInputReshapeType(node, insert_index); } bool need_padding = false; if (is_insert_input) { diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast_for_runop.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast_for_runop.cc deleted file mode 100644 index 7647b86c17..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast_for_runop.cc +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "pre_activate/ascend/format_type/insert_cast_for_runop.h" - -#include - -#include "device/kernel_info.h" -#include "pre_activate/ascend/ascend_helper.h" -#include "pre_activate/common/helper.h" -#include "kernel/oplib/oplib.h" -#include "session/anf_runtime_algorithm.h" -#include "utils/utils.h" - -namespace mindspore { -namespace opt { -const BaseRef RunOpInsertCast::DefinePattern() const { - VarPtr V = std::make_shared(UnVisited); - VarPtr Xs = std::make_shared(); - return VectorRef({V, Xs}); -} - -const AnfNodePtr RunOpInsertCast::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const EquivPtr &) const { - MS_EXCEPTION_IF_NULL(node); - if (!AnfAlgo::IsRealCNodeKernel(node) || func_graph == nullptr) { - return nullptr; - } - AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); - // process input - CNodePtr cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - return InsertCastForInput(func_graph, cnode); -} -} // namespace opt -} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast_for_runop.h b/mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast_for_runop.h deleted file mode 100644 index 4467cc5198..0000000000 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast_for_runop.h +++ /dev/null @@ -1,35 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_FOR_RUNOP_H_ -#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_FOR_RUNOP_H_ -#include - -#include "pre_activate/common/optimizer.h" -#include "pre_activate/common/pattern_engine.h" -#include "ir/anf.h" -namespace mindspore { -namespace opt { -class RunOpInsertCast : public PatternProcessPass { - public: - explicit RunOpInsertCast(bool multigraph = true) : PatternProcessPass("insert_cast_for_runop", multigraph) {} - ~RunOpInsertCast() override = default; - const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; -}; -} // namespace opt -} // namespace mindspore - -#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_FOR_RUNOP_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc new file mode 100644 index 0000000000..82aad853c3 --- /dev/null +++ b/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc @@ -0,0 +1,154 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h" + +#include +#include +#include +#include + +#include "session/anf_runtime_algorithm.h" +#include "kernel/kernel_build_info.h" +#include "utils/utils.h" +#include "kernel/common_utils.h" +#include "utils/context/ms_context.h" + +namespace mindspore { +namespace opt { +const BaseRef RectifyDoMaskKernelInfo::DefinePattern() const { + VarPtr X = std::make_shared(); + VarPtr Xs = std::make_shared(); + return VectorRef({X, Xs}); +} + +const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, + const EquivPtr &) const { + if (node == nullptr || !node->isa()) { + return nullptr; + } + auto cnode = node->cast(); + auto ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + if (ms_context->execution_mode() == kPynativeMode) { + if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutDoMask->name()) { + return nullptr; + } + auto do_mask_input_format = AnfAlgo::GetInputFormat(node, 0); + if (do_mask_input_format != kOpFormat_DEFAULT) { + auto builder = + std::make_shared(AnfAlgo::GetSelectKernelBuildInfo(node)); + builder->SetInputFormat(kOpFormat_DEFAULT, 0); + builder->SetOutputFormat(kOpFormat_DEFAULT, 0); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); + } + return nullptr; + } + if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutGenMask->name()) { + return nullptr; + } + std::vector do_mask_node_list; + auto manager = graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto node_map = manager->node_users(); + auto iter = node_map.find(node); + if (iter == node_map.end()) { + MS_LOG(EXCEPTION) << "Cannot find the node " << node->DebugString() << " in the graph manager!"; + } + auto gen_mask_output_nodes = iter->second; + for (const auto &output_node : gen_mask_output_nodes) { + if (AnfAlgo::GetCNodeName(output_node.first) == prim::kPrimDropoutDoMask->name()) { + auto output_cnode = 
output_node.first->cast(); + do_mask_node_list.push_back(output_cnode); + } + } + std::vector input_shape; + for (const auto &output_node : do_mask_node_list) { + if (input_shape.empty()) { + input_shape = AnfAlgo::GetPrevNodeOutputInferShape(output_node, 0); + continue; + } + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(output_node, 0); + if (!kernel::IsSameShape(shape, input_shape)) { + MS_LOG(EXCEPTION) << "The DropOutGenMask connected with same genmask's shape must be equal!" + << " GenMask " << node->DebugString(); + } + } + RectifyKernelInfo(do_mask_node_list); + return nullptr; +} + +void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector &do_mask_node_list) const { + std::map format_counter; + std::string special_format; + std::string convert_format; + for (const auto &do_mask : do_mask_node_list) { + auto do_mask_data_format = AnfAlgo::GetInputFormat(do_mask, 0); + if (special_format.empty() && kNeedTransFormatSet.find(do_mask_data_format) != kNeedTransFormatSet.end()) { + special_format = do_mask_data_format; + } + if (format_counter.find(do_mask_data_format) == format_counter.end()) { + format_counter[do_mask_data_format] = 1; + } else { + format_counter[do_mask_data_format] = format_counter[do_mask_data_format] + 1; + } + // if has two or more special format we need change all domask's format to default that can avoid insert more + // transdata + if (format_counter.size() > 2) { + convert_format = kOpFormat_DEFAULT; + break; + } + if (kNeedTransFormatSet.find(do_mask_data_format) != kNeedTransFormatSet.end() && + special_format != do_mask_data_format) { + convert_format = kOpFormat_DEFAULT; + break; + } + } + if (format_counter.size() == 1) { + return; + } + if (convert_format.empty()) { + convert_format = GetConvertFormat(format_counter); + } + RectifyDropOutDoMaskKernelInfo(do_mask_node_list, convert_format); +} + +std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map &format_counter) const { + std::string convert_format; 
+ size_t counter = 0; + for (const auto &iter : format_counter) { + if (counter < iter.second) { + convert_format = iter.first; counter = iter.second; + } + if (counter == iter.second && kNeedTransFormatSet.find(convert_format) == kNeedTransFormatSet.end()) { + convert_format = iter.first; + } + } + return convert_format; +} +void RectifyDoMaskKernelInfo::RectifyDropOutDoMaskKernelInfo(const std::vector &do_mask_node_list, + const std::string &format) const { + for (const auto &do_mask : do_mask_node_list) { + auto builder = + std::make_shared(AnfAlgo::GetSelectKernelBuildInfo(do_mask)); + builder->SetInputFormat(format, 0); + builder->SetOutputFormat(format, 0); + AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), do_mask.get()); + } +} + +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h b/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h new file mode 100644 index 0000000000..83f7e397bd --- /dev/null +++ b/mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_RECTIFY_DO_MASK_KERNEL_INFO_H +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_RECTIFY_DO_MASK_KERNEL_INFO_H +#include +#include +#include + +#include "pre_activate/common/optimizer.h" +namespace mindspore { +namespace opt { +class RectifyDoMaskKernelInfo : public PatternProcessPass { + public: + explicit RectifyDoMaskKernelInfo(bool multigraph = true) + : PatternProcessPass("rectify_do_mask_kernel_info", multigraph) {} + ~RectifyDoMaskKernelInfo() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + void RectifyKernelInfo(const std::vector &do_mask_node_list) const; + std::string GetConvertFormat(const std::map &format_counter) const; + void RectifyDropOutDoMaskKernelInfo(const std::vector &do_mask_node_list, const std::string &format) const; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_RECTIFY_DO_MASK_KERNEL_INFO_H