Merge pull request !1785 from lianliguang/add-dropout-kernel-special-kernel-select-rulestags/v0.5.0-beta
| @@ -31,7 +31,7 @@ enum KernelType : int { UNKNOWN_KERNEL_TYPE = 0, AUTO_DIFF_KERNEL, AICPU_KERNEL, | |||
| namespace kernel { | |||
| enum Axis { | |||
| enum Axis : int { | |||
| N = 0, | |||
| C, | |||
| H, | |||
| @@ -167,5 +167,20 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetOpPattern(OpPattern pattern) { | |||
| MS_EXCEPTION_IF_NULL(kernel_build_info_); | |||
| kernel_build_info_->op_pattern_ = pattern; | |||
| } | |||
| void KernelBuildInfo::KernelBuildInfoBuilder::SetInputFormat(const std::string &format, size_t index) { | |||
| MS_EXCEPTION_IF_NULL(kernel_build_info_); | |||
| if (index >= kernel_build_info_->inputs_format_.size()) { | |||
| MS_LOG(EXCEPTION) << "index outof range!"; | |||
| } | |||
| kernel_build_info_->inputs_format_[index] = format; | |||
| } | |||
| void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputFormat(const std::string &format, size_t index) { | |||
| MS_EXCEPTION_IF_NULL(kernel_build_info_); | |||
| if (index >= kernel_build_info_->outputs_format_.size()) { | |||
| MS_LOG(EXCEPTION) << "index outof range!"; | |||
| } | |||
| kernel_build_info_->outputs_format_[index] = format; | |||
| } | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -131,6 +131,10 @@ class KernelBuildInfo::KernelBuildInfoBuilder { | |||
| void SetOpPattern(OpPattern pattern); | |||
| void SetInputFormat(const std::string &format, size_t index); | |||
| void SetOutputFormat(const std::string &format, size_t index); | |||
| std::shared_ptr<KernelBuildInfo> Build(); | |||
| private: | |||
| @@ -41,8 +41,16 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node, | |||
| } else { | |||
| MS_LOG(WARNING) << "All kernel Info list does not match any kernel info "; | |||
| for (size_t index = 0; index < kernel_info_list->size(); ++index) { | |||
| std::ostringstream buffer; | |||
| MS_EXCEPTION_IF_NULL(kernel_info_list->at(index)); | |||
| MS_LOG(WARNING) << "kernel [ " << index << " ] :" << kernel_info_list->at(index)->ToString(); | |||
| if (AnfAlgo::GetOutputTensorNum(kernel_node) != kernel_info_list->at(index)->GetOutputNum()) { | |||
| buffer << "Kernel node's output size [" << AnfAlgo::GetOutputTensorNum(kernel_node) << "]" | |||
| << " cannot match the kernel's output size [" << kernel_info_list->at(index)->GetOutputNum() << "]"; | |||
| } else { | |||
| buffer << "Kernel node's output size [" << AnfAlgo::GetInputTensorNum(kernel_node) << "]" | |||
| << " cannot match the kernel's output size [" << kernel_info_list->at(index)->GetInputNum() << "]"; | |||
| } | |||
| MS_LOG(WARNING) << "kernel [ " << index << " ] :" << kernel_info_list->at(index)->ToString() << buffer.str(); | |||
| } | |||
| kernel_info_list->clear(); | |||
| MS_LOG(WARNING) << "node" << kernel_node->DebugString() << "'s output size : [" | |||
| @@ -208,6 +208,7 @@ const PrimitivePtr kPrimLayerNormGrad = std::make_shared<Primitive>("LayerNormGr | |||
| const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared<Primitive>("LayerNormXBackprop"); | |||
| const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared<Primitive>("LayerNormBetaGammaBackprop"); | |||
| const PrimitivePtr kPrimDropoutGenMask = std::make_shared<Primitive>("DropoutGenMask"); | |||
| const PrimitivePtr kPrimDropoutDoMask = std::make_shared<Primitive>("DropoutDoMask"); | |||
| const PrimitivePtr kPrimOneHot = std::make_shared<Primitive>("OneHot"); | |||
| const PrimitivePtr kPrimGelu = std::make_shared<Primitive>("Gelu"); | |||
| const PrimitivePtr kPrimGeluGrad = std::make_shared<Primitive>("GeluGrad"); | |||
| @@ -214,6 +214,7 @@ extern const PrimitivePtr kPrimLayerNormGrad; | |||
| extern const PrimitivePtr kPrimLayerNormXBackprop; | |||
| extern const PrimitivePtr kPrimLayerNormBetaGammaBackprop; | |||
| extern const PrimitivePtr kPrimDropoutGenMask; | |||
| extern const PrimitivePtr kPrimDropoutDoMask; | |||
| extern const PrimitivePtr kPrimOneHot; | |||
| extern const PrimitivePtr kPrimGelu; | |||
| extern const PrimitivePtr kPrimGeluGrad; | |||
| @@ -55,6 +55,7 @@ | |||
| #include "pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.h" | |||
| #include "pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion.h" | |||
| #include "pre_activate/ascend/format_type/insert_trans_op.h" | |||
| #include "pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h" | |||
| #include "pre_activate/pass/getitem_tuple.h" | |||
| #include "pre_activate/pass/optimize_dependence.h" | |||
| #include "pre_activate/pass/erase_visit_attr.h" | |||
| @@ -82,7 +83,6 @@ | |||
| #include "pre_activate/ascend/format_type/deal_ref_trans_and_cast.h" | |||
| #include "pre_activate/ascend/enhancer/add_memcpy_async.h" | |||
| #include "pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.h" | |||
| #include "pre_activate/ascend/format_type/insert_cast_for_runop.h" | |||
| #include "pre_activate/ascend/format_type/insert_transdata_for_runop.h" | |||
| #include "pre_activate/ascend/enhancer/getnext_memcpy_elimination.h" | |||
| #include "pre_activate/ascend/ir_fission/addn_fission.h" | |||
| @@ -148,6 +148,7 @@ void RunOpAscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_g | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| auto optimizer = std::make_shared<GraphOptimizer>(); | |||
| auto data_layout_pm = std::make_shared<PassManager>("pynative_transop_pm"); | |||
| data_layout_pm->AddPass(std::make_shared<RectifyDoMaskKernelInfo>()); | |||
| data_layout_pm->AddPass(std::make_shared<RunOpInsertTransData>()); | |||
| data_layout_pm->AddPass(std::make_shared<GetitemTuple>()); | |||
| data_layout_pm->AddPass(std::make_shared<CommonSubexpressionElimination>()); | |||
| @@ -160,30 +161,11 @@ void RunOpAscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_g | |||
| kernel_graph->SetExecOrderByDefault(); | |||
| } | |||
| void RunOpAscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_graph) { | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| auto optimizer = std::make_shared<GraphOptimizer>(); | |||
| auto mixed_precision_pm = std::make_shared<PassManager>("pynative_transop_pm"); | |||
| mixed_precision_pm->AddPass(std::make_shared<RunOpInsertCast>()); | |||
| mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>()); | |||
| mixed_precision_pm->AddPass(std::make_shared<CommonSubexpressionElimination>()); | |||
| mixed_precision_pm->AddPass(std::make_shared<EliminateRedundantOp>()); | |||
| mixed_precision_pm->AddPass(std::make_shared<OptimizeDependence>()); | |||
| mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>()); | |||
| mixed_precision_pm->AddPass(std::make_shared<DealRefTransAndCast>()); | |||
| mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>()); | |||
| mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>()); | |||
| mixed_precision_pm->AddPass(std::make_shared<LayerNormBetaGammaBackpropFusion>()); | |||
| mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>()); | |||
| optimizer->AddPassManager(mixed_precision_pm); | |||
| (void)optimizer->Optimize(kernel_graph); | |||
| kernel_graph->SetExecOrderByDefault(); | |||
| } | |||
| void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph) { | |||
| MS_EXCEPTION_IF_NULL(kernel_graph); | |||
| auto optimizer = std::make_shared<GraphOptimizer>(); | |||
| auto data_layout_pm = std::make_shared<PassManager>("transop_pm"); | |||
| data_layout_pm->AddPass(std::make_shared<RectifyDoMaskKernelInfo>()); | |||
| data_layout_pm->AddPass(std::make_shared<InsertTransOp>()); | |||
| data_layout_pm->AddPass(std::make_shared<GetitemTuple>()); | |||
| data_layout_pm->AddPass(std::make_shared<CommonSubexpressionElimination>()); | |||
| @@ -20,7 +20,6 @@ | |||
| namespace mindspore { | |||
| namespace opt { | |||
| void RunOpAscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph); | |||
| void RunOpAscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_graph); | |||
| void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph); | |||
| void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph); | |||
| void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_graph); | |||
| @@ -65,7 +65,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt | |||
| dtype = AnfAlgo::GetInputDeviceDataType(cnode, insert_index); | |||
| dst_format = AnfAlgo::GetInputFormat(cnode, insert_index); | |||
| input_node = AnfAlgo::GetInputNode(cnode, insert_index); | |||
| padding_axis = AnfAlgo::GetInputReshapeType(node, 0); | |||
| padding_axis = AnfAlgo::GetInputReshapeType(node, insert_index); | |||
| } | |||
| bool need_padding = false; | |||
| if (is_insert_input) { | |||
| @@ -1,48 +0,0 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "pre_activate/ascend/format_type/insert_cast_for_runop.h" | |||
| #include <memory> | |||
| #include "device/kernel_info.h" | |||
| #include "pre_activate/ascend/ascend_helper.h" | |||
| #include "pre_activate/common/helper.h" | |||
| #include "kernel/oplib/oplib.h" | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "utils/utils.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| const BaseRef RunOpInsertCast::DefinePattern() const { | |||
| VarPtr V = std::make_shared<CondVar>(UnVisited); | |||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||
| return VectorRef({V, Xs}); | |||
| } | |||
| const AnfNodePtr RunOpInsertCast::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| MS_EXCEPTION_IF_NULL(node); | |||
| if (!AnfAlgo::IsRealCNodeKernel(node) || func_graph == nullptr) { | |||
| return nullptr; | |||
| } | |||
| AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node); | |||
| // process input | |||
| CNodePtr cnode = node->cast<CNodePtr>(); | |||
| MS_EXCEPTION_IF_NULL(cnode); | |||
| return InsertCastForInput(func_graph, cnode); | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -1,35 +0,0 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_FOR_RUNOP_H_ | |||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_FOR_RUNOP_H_ | |||
| #include <string> | |||
| #include "pre_activate/common/optimizer.h" | |||
| #include "pre_activate/common/pattern_engine.h" | |||
| #include "ir/anf.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class RunOpInsertCast : public PatternProcessPass { | |||
| public: | |||
| explicit RunOpInsertCast(bool multigraph = true) : PatternProcessPass("insert_cast_for_runop", multigraph) {} | |||
| ~RunOpInsertCast() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_FOR_RUNOP_H_ | |||
| @@ -0,0 +1,154 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h" | |||
| #include <vector> | |||
| #include <map> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "session/anf_runtime_algorithm.h" | |||
| #include "kernel/kernel_build_info.h" | |||
| #include "utils/utils.h" | |||
| #include "kernel/common_utils.h" | |||
| #include "utils/context/ms_context.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| const BaseRef RectifyDoMaskKernelInfo::DefinePattern() const { | |||
| VarPtr X = std::make_shared<Var>(); | |||
| VarPtr Xs = std::make_shared<SeqVar>(); | |||
| return VectorRef({X, Xs}); | |||
| } | |||
| const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, const AnfNodePtr &node, | |||
| const EquivPtr &) const { | |||
| if (node == nullptr || !node->isa<CNode>()) { | |||
| return nullptr; | |||
| } | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| auto ms_context = MsContext::GetInstance(); | |||
| MS_EXCEPTION_IF_NULL(ms_context); | |||
| if (ms_context->execution_mode() == kPynativeMode) { | |||
| if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutDoMask->name()) { | |||
| return nullptr; | |||
| } | |||
| auto do_mask_input_format = AnfAlgo::GetInputFormat(node, 0); | |||
| if (do_mask_input_format != kOpFormat_DEFAULT) { | |||
| auto builder = | |||
| std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node)); | |||
| builder->SetInputFormat(kOpFormat_DEFAULT, 0); | |||
| builder->SetOutputFormat(kOpFormat_DEFAULT, 0); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get()); | |||
| } | |||
| return nullptr; | |||
| } | |||
| if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutGenMask->name()) { | |||
| return nullptr; | |||
| } | |||
| std::vector<CNodePtr> do_mask_node_list; | |||
| auto manager = graph->manager(); | |||
| MS_EXCEPTION_IF_NULL(manager); | |||
| auto node_map = manager->node_users(); | |||
| auto iter = node_map.find(node); | |||
| if (iter == node_map.end()) { | |||
| MS_LOG(EXCEPTION) << "Cannot find the node " << node->DebugString() << " in the graph manager!"; | |||
| } | |||
| auto gen_mask_output_nodes = iter->second; | |||
| for (const auto &output_node : gen_mask_output_nodes) { | |||
| if (AnfAlgo::GetCNodeName(output_node.first) == prim::kPrimDropoutDoMask->name()) { | |||
| auto output_cnode = output_node.first->cast<CNodePtr>(); | |||
| do_mask_node_list.push_back(output_cnode); | |||
| } | |||
| } | |||
| std::vector<size_t> input_shape; | |||
| for (const auto &output_node : do_mask_node_list) { | |||
| if (input_shape.empty()) { | |||
| input_shape = AnfAlgo::GetPrevNodeOutputInferShape(output_node, 0); | |||
| continue; | |||
| } | |||
| auto shape = AnfAlgo::GetPrevNodeOutputInferShape(output_node, 0); | |||
| if (!kernel::IsSameShape(shape, input_shape)) { | |||
| MS_LOG(EXCEPTION) << "The DropOutGenMask connected with same genmask's shape must be equal!" | |||
| << " GenMask " << node->DebugString(); | |||
| } | |||
| } | |||
| RectifyKernelInfo(do_mask_node_list); | |||
| return nullptr; | |||
| } | |||
| void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector<CNodePtr> &do_mask_node_list) const { | |||
| std::map<std::string, size_t> format_counter; | |||
| std::string special_format; | |||
| std::string convert_format; | |||
| for (const auto &do_mask : do_mask_node_list) { | |||
| auto do_mask_data_format = AnfAlgo::GetInputFormat(do_mask, 0); | |||
| if (special_format.empty() && kNeedTransFormatSet.find(do_mask_data_format) != kNeedTransFormatSet.end()) { | |||
| special_format = do_mask_data_format; | |||
| } | |||
| if (format_counter.find(do_mask_data_format) == format_counter.end()) { | |||
| format_counter[do_mask_data_format] = 1; | |||
| } else { | |||
| format_counter[do_mask_data_format] = format_counter[do_mask_data_format] + 1; | |||
| } | |||
| // if has two or more special format we need change all domask's format to default that can avoid insert more | |||
| // transdata | |||
| if (format_counter.size() > 2) { | |||
| convert_format = kOpFormat_DEFAULT; | |||
| break; | |||
| } | |||
| if (kNeedTransFormatSet.find(do_mask_data_format) != kNeedTransFormatSet.end() && | |||
| special_format != do_mask_data_format) { | |||
| convert_format = kOpFormat_DEFAULT; | |||
| break; | |||
| } | |||
| } | |||
| if (format_counter.size() == 1) { | |||
| return; | |||
| } | |||
| if (convert_format.empty()) { | |||
| convert_format = GetConvertFormat(format_counter); | |||
| } | |||
| RectifyDropOutDoMaskKernelInfo(do_mask_node_list, convert_format); | |||
| } | |||
| std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map<std::string, size_t> &format_counter) const { | |||
| std::string convert_format; | |||
| size_t counter = 0; | |||
| for (const auto &iter : format_counter) { | |||
| if (counter < iter.second) { | |||
| convert_format = iter.first; | |||
| } | |||
| if (counter == iter.second && kNeedTransFormatSet.find(convert_format) == kNeedTransFormatSet.end()) { | |||
| convert_format = iter.first; | |||
| } | |||
| } | |||
| return convert_format; | |||
| } | |||
| void RectifyDoMaskKernelInfo::RectifyDropOutDoMaskKernelInfo(const std::vector<CNodePtr> &do_mask_node_list, | |||
| const std::string &format) const { | |||
| for (const auto &do_mask : do_mask_node_list) { | |||
| auto builder = | |||
| std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(do_mask)); | |||
| builder->SetInputFormat(format, 0); | |||
| builder->SetOutputFormat(format, 0); | |||
| AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), do_mask.get()); | |||
| } | |||
| } | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,41 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_RECTIFY_DO_MASK_KERNEL_INFO_H | |||
| #define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_RECTIFY_DO_MASK_KERNEL_INFO_H | |||
| #include <map> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "pre_activate/common/optimizer.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| class RectifyDoMaskKernelInfo : public PatternProcessPass { | |||
| public: | |||
| explicit RectifyDoMaskKernelInfo(bool multigraph = true) | |||
| : PatternProcessPass("batch_norm_bert_fission", multigraph) {} | |||
| ~RectifyDoMaskKernelInfo() override = default; | |||
| const BaseRef DefinePattern() const override; | |||
| const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; | |||
| private: | |||
| void RectifyKernelInfo(const std::vector<CNodePtr> &do_mask_node_list) const; | |||
| std::string GetConvertFormat(const std::map<std::string, size_t> &format_counter) const; | |||
| void RectifyDropOutDoMaskKernelInfo(const std::vector<CNodePtr> &do_mask_node_list, const std::string &format) const; | |||
| }; | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_RECTIFY_DO_MASK_KERNEL_INFO_H | |||