
!1785 add dropout special kernel selection rules

Merge pull request !1785 from lianliguang/add-dropout-kernel-special-kernel-select-rules
Tag: v0.5.0-beta
Merged by: mindspore-ci-bot (Gitee), 5 years ago
Parent commit: 686ced85ad
13 changed files with 230 additions and 108 deletions

1.  mindspore/ccsrc/kernel/kernel.h (+1 −1)
2.  mindspore/ccsrc/kernel/kernel_build_info.cc (+15 −0)
3.  mindspore/ccsrc/kernel/kernel_build_info.h (+4 −0)
4.  mindspore/ccsrc/kernel/kernel_query.cc (+9 −1)
5.  mindspore/ccsrc/operator/ops.cc (+1 −0)
6.  mindspore/ccsrc/operator/ops.h (+1 −0)
7.  mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc (+3 −21)
8.  mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.h (+0 −1)
9.  mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc (+1 −1)
10. mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast_for_runop.cc (+0 −48)
11. mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast_for_runop.h (+0 −35)
12. mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc (+154 −0)
13. mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h (+41 −0)

mindspore/ccsrc/kernel/kernel.h (+1 −1)

@@ -31,7 +31,7 @@ enum KernelType : int { UNKNOWN_KERNEL_TYPE = 0, AUTO_DIFF_KERNEL, AICPU_KERNEL,
 namespace kernel {

-enum Axis {
+enum Axis : int {
   N = 0,
   C,
   H,


mindspore/ccsrc/kernel/kernel_build_info.cc (+15 −0)

@@ -167,5 +167,20 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetOpPattern(OpPattern pattern) {
   MS_EXCEPTION_IF_NULL(kernel_build_info_);
   kernel_build_info_->op_pattern_ = pattern;
 }
+void KernelBuildInfo::KernelBuildInfoBuilder::SetInputFormat(const std::string &format, size_t index) {
+  MS_EXCEPTION_IF_NULL(kernel_build_info_);
+  if (index >= kernel_build_info_->inputs_format_.size()) {
+    MS_LOG(EXCEPTION) << "index out of range!";
+  }
+  kernel_build_info_->inputs_format_[index] = format;
+}
+
+void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputFormat(const std::string &format, size_t index) {
+  MS_EXCEPTION_IF_NULL(kernel_build_info_);
+  if (index >= kernel_build_info_->outputs_format_.size()) {
+    MS_LOG(EXCEPTION) << "index out of range!";
+  }
+  kernel_build_info_->outputs_format_[index] = format;
+}
 }  // namespace kernel
 }  // namespace mindspore
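
For readers outside the MindSpore tree: the two new setters follow an index-checked in-place mutation pattern. Below is a minimal, self-contained sketch of that pattern; KernelInfo and its members are hypothetical stand-ins, not the real KernelBuildInfo API.

#include <cstddef>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Stand-in for KernelBuildInfo: holds one format string per input.
class KernelInfo {
 public:
  explicit KernelInfo(std::vector<std::string> inputs_format) : inputs_format_(std::move(inputs_format)) {}

  // Overwrite one input format in place; reject out-of-range indices instead of silently growing the vector.
  void SetInputFormat(const std::string &format, size_t index) {
    if (index >= inputs_format_.size()) {
      throw std::out_of_range("index out of range!");
    }
    inputs_format_[index] = format;
  }

 private:
  std::vector<std::string> inputs_format_;
};

int main() {
  KernelInfo info({"NC1HWC0", "DefaultFormat"});
  info.SetInputFormat("DefaultFormat", 0);  // ok: rewrites input 0
  // info.SetInputFormat("NC1HWC0", 5);     // would throw: only two inputs exist
  return 0;
}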

mindspore/ccsrc/kernel/kernel_build_info.h (+4 −0)

@@ -131,6 +131,10 @@ class KernelBuildInfo::KernelBuildInfoBuilder {
   void SetOpPattern(OpPattern pattern);

+  void SetInputFormat(const std::string &format, size_t index);
+
+  void SetOutputFormat(const std::string &format, size_t index);
+
   std::shared_ptr<KernelBuildInfo> Build();

  private:

mindspore/ccsrc/kernel/kernel_query.cc (+9 −1)

@@ -41,8 +41,16 @@ void FilterInvalidKernelInfo(const CNodePtr &kernel_node,
   } else {
     MS_LOG(WARNING) << "All kernel Info list does not match any kernel info ";
     for (size_t index = 0; index < kernel_info_list->size(); ++index) {
+      std::ostringstream buffer;
       MS_EXCEPTION_IF_NULL(kernel_info_list->at(index));
-      MS_LOG(WARNING) << "kernel [ " << index << " ] :" << kernel_info_list->at(index)->ToString();
+      if (AnfAlgo::GetOutputTensorNum(kernel_node) != kernel_info_list->at(index)->GetOutputNum()) {
+        buffer << "Kernel node's output size [" << AnfAlgo::GetOutputTensorNum(kernel_node) << "]"
+               << " cannot match the kernel's output size [" << kernel_info_list->at(index)->GetOutputNum() << "]";
+      } else {
+        buffer << "Kernel node's input size [" << AnfAlgo::GetInputTensorNum(kernel_node) << "]"
+               << " cannot match the kernel's input size [" << kernel_info_list->at(index)->GetInputNum() << "]";
+      }
+      MS_LOG(WARNING) << "kernel [ " << index << " ] :" << kernel_info_list->at(index)->ToString() << buffer.str();
     }
     kernel_info_list->clear();
     MS_LOG(WARNING) << "node" << kernel_node->DebugString() << "'s output size : ["
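
The new warning path accumulates a per-candidate explanation in an std::ostringstream before logging, branching on whether the output or the input arity mismatched. A standalone sketch of that logic; the function name and parameters here are illustrative, not MindSpore APIs:

#include <cstddef>
#include <iostream>
#include <sstream>
#include <string>

// Compare a node's I/O arity against one candidate kernel and explain the
// first mismatch found, mirroring the shape of the new warning loop.
std::string DescribeMismatch(size_t node_outputs, size_t kernel_outputs, size_t node_inputs, size_t kernel_inputs) {
  std::ostringstream buffer;
  if (node_outputs != kernel_outputs) {
    buffer << "Kernel node's output size [" << node_outputs << "] cannot match the kernel's output size ["
           << kernel_outputs << "]";
  } else {
    buffer << "Kernel node's input size [" << node_inputs << "] cannot match the kernel's input size ["
           << kernel_inputs << "]";
  }
  return buffer.str();
}

int main() {
  // A node with 1 output and 3 inputs vs. a candidate kernel with 2 outputs and 3 inputs.
  std::cout << DescribeMismatch(1, 2, 3, 3) << std::endl;
  return 0;
}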


mindspore/ccsrc/operator/ops.cc (+1 −0)

@@ -208,6 +208,7 @@ const PrimitivePtr kPrimLayerNormGrad = std::make_shared<Primitive>("LayerNormGrad");
 const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared<Primitive>("LayerNormXBackprop");
 const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared<Primitive>("LayerNormBetaGammaBackprop");
 const PrimitivePtr kPrimDropoutGenMask = std::make_shared<Primitive>("DropoutGenMask");
+const PrimitivePtr kPrimDropoutDoMask = std::make_shared<Primitive>("DropoutDoMask");
 const PrimitivePtr kPrimOneHot = std::make_shared<Primitive>("OneHot");
 const PrimitivePtr kPrimGelu = std::make_shared<Primitive>("Gelu");
 const PrimitivePtr kPrimGeluGrad = std::make_shared<Primitive>("GeluGrad");

mindspore/ccsrc/operator/ops.h (+1 −0)

@@ -214,6 +214,7 @@ extern const PrimitivePtr kPrimLayerNormGrad;
 extern const PrimitivePtr kPrimLayerNormXBackprop;
 extern const PrimitivePtr kPrimLayerNormBetaGammaBackprop;
 extern const PrimitivePtr kPrimDropoutGenMask;
+extern const PrimitivePtr kPrimDropoutDoMask;
 extern const PrimitivePtr kPrimOneHot;
 extern const PrimitivePtr kPrimGelu;
 extern const PrimitivePtr kPrimGeluGrad;
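
These kPrim* constants are shared Primitive handles that optimization passes compare against by name; the new pass later in this commit does exactly that with kPrimDropoutDoMask. A self-contained sketch of the idiom, with a stand-in Primitive type rather than MindSpore's real class:

#include <iostream>
#include <memory>
#include <string>
#include <utility>

// Stand-in for mindspore's Primitive: an operator identified by its name.
class Primitive {
 public:
  explicit Primitive(std::string name) : name_(std::move(name)) {}
  const std::string &name() const { return name_; }

 private:
  std::string name_;
};
using PrimitivePtr = std::shared_ptr<Primitive>;

// Defined once (as in ops.cc) and shared via an extern declaration (as in ops.h).
const PrimitivePtr kPrimDropoutDoMask = std::make_shared<Primitive>("DropoutDoMask");

int main() {
  std::string cnode_name = "DropoutDoMask";  // stands in for AnfAlgo::GetCNodeName(cnode)
  if (cnode_name == kPrimDropoutDoMask->name()) {
    std::cout << "matched a DropoutDoMask node" << std::endl;
  }
  return 0;
}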


mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc (+3 −21)

@@ -55,6 +55,7 @@
 #include "pre_activate/ascend/ir_fusion/confusion_mul_grad_fusion.h"
 #include "pre_activate/ascend/ir_fusion/softmax_grad_ext_fusion.h"
 #include "pre_activate/ascend/format_type/insert_trans_op.h"
+#include "pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h"
 #include "pre_activate/pass/getitem_tuple.h"
 #include "pre_activate/pass/optimize_dependence.h"
 #include "pre_activate/pass/erase_visit_attr.h"
@@ -82,7 +83,6 @@
 #include "pre_activate/ascend/format_type/deal_ref_trans_and_cast.h"
 #include "pre_activate/ascend/enhancer/add_memcpy_async.h"
 #include "pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.h"
-#include "pre_activate/ascend/format_type/insert_cast_for_runop.h"
 #include "pre_activate/ascend/format_type/insert_transdata_for_runop.h"
 #include "pre_activate/ascend/enhancer/getnext_memcpy_elimination.h"
 #include "pre_activate/ascend/ir_fission/addn_fission.h"
@@ -148,6 +148,7 @@ void RunOpAscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   MS_EXCEPTION_IF_NULL(kernel_graph);
   auto optimizer = std::make_shared<GraphOptimizer>();
   auto data_layout_pm = std::make_shared<PassManager>("pynative_transop_pm");
+  data_layout_pm->AddPass(std::make_shared<RectifyDoMaskKernelInfo>());
   data_layout_pm->AddPass(std::make_shared<RunOpInsertTransData>());
   data_layout_pm->AddPass(std::make_shared<GetitemTuple>());
   data_layout_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
@@ -160,30 +161,11 @@ void RunOpAscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   kernel_graph->SetExecOrderByDefault();
 }

-void RunOpAscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
-  MS_EXCEPTION_IF_NULL(kernel_graph);
-  auto optimizer = std::make_shared<GraphOptimizer>();
-  auto mixed_precision_pm = std::make_shared<PassManager>("pynative_transop_pm");
-  mixed_precision_pm->AddPass(std::make_shared<RunOpInsertCast>());
-  mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>());
-  mixed_precision_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
-  mixed_precision_pm->AddPass(std::make_shared<EliminateRedundantOp>());
-  mixed_precision_pm->AddPass(std::make_shared<OptimizeDependence>());
-  mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>());
-  mixed_precision_pm->AddPass(std::make_shared<DealRefTransAndCast>());
-  mixed_precision_pm->AddPass(std::make_shared<GetitemTuple>());
-  mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>());
-  mixed_precision_pm->AddPass(std::make_shared<LayerNormBetaGammaBackpropFusion>());
-  mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>());
-  optimizer->AddPassManager(mixed_precision_pm);
-  (void)optimizer->Optimize(kernel_graph);
-  kernel_graph->SetExecOrderByDefault();
-}
-
 void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   MS_EXCEPTION_IF_NULL(kernel_graph);
   auto optimizer = std::make_shared<GraphOptimizer>();
   auto data_layout_pm = std::make_shared<PassManager>("transop_pm");
+  data_layout_pm->AddPass(std::make_shared<RectifyDoMaskKernelInfo>());
   data_layout_pm->AddPass(std::make_shared<InsertTransOp>());
   data_layout_pm->AddPass(std::make_shared<GetitemTuple>());
   data_layout_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
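
The registration order matters here: RectifyDoMaskKernelInfo is added before RunOpInsertTransData / InsertTransOp, so DoMask kernel formats are unified before any transdata nodes get materialized. A simplified sketch of the pass-manager idea (the real classes operate on kernel graphs, not lambdas):

#include <functional>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Passes run strictly in registration order, so a kernel-info fixup registered
// first is guaranteed to run before the pass that inserts transdata nodes.
class PassManager {
 public:
  void AddPass(std::string name, std::function<void()> pass) {
    passes_.emplace_back(std::move(name), std::move(pass));
  }
  void Run() const {
    for (const auto &pass : passes_) {
      std::cout << "running pass: " << pass.first << std::endl;
      pass.second();
    }
  }

 private:
  std::vector<std::pair<std::string, std::function<void()>>> passes_;
};

int main() {
  PassManager data_layout_pm;
  data_layout_pm.AddPass("RectifyDoMaskKernelInfo", [] { /* unify DoMask formats first */ });
  data_layout_pm.AddPass("InsertTransOp", [] { /* then insert transdata for remaining mismatches */ });
  data_layout_pm.Run();
  return 0;
}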


mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.h (+0 −1)

@@ -20,7 +20,6 @@
 namespace mindspore {
 namespace opt {
 void RunOpAscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph);
-void RunOpAscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_graph);
 void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
 void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph);
 void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_graph);


mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc (+1 −1)

@@ -65,7 +65,7 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr
     dtype = AnfAlgo::GetInputDeviceDataType(cnode, insert_index);
     dst_format = AnfAlgo::GetInputFormat(cnode, insert_index);
     input_node = AnfAlgo::GetInputNode(cnode, insert_index);
-    padding_axis = AnfAlgo::GetInputReshapeType(node, 0);
+    padding_axis = AnfAlgo::GetInputReshapeType(node, insert_index);
   }
   bool need_padding = false;
   if (is_insert_input) {


mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast_for_runop.cc (+0 −48, file deleted)

@@ -1,48 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/ascend/format_type/insert_cast_for_runop.h"

#include <memory>

#include "device/kernel_info.h"
#include "pre_activate/ascend/ascend_helper.h"
#include "pre_activate/common/helper.h"
#include "kernel/oplib/oplib.h"
#include "session/anf_runtime_algorithm.h"
#include "utils/utils.h"

namespace mindspore {
namespace opt {
const BaseRef RunOpInsertCast::DefinePattern() const {
VarPtr V = std::make_shared<CondVar>(UnVisited);
VarPtr Xs = std::make_shared<SeqVar>();
return VectorRef({V, Xs});
}

const AnfNodePtr RunOpInsertCast::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(node);
if (!AnfAlgo::IsRealCNodeKernel(node) || func_graph == nullptr) {
return nullptr;
}
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
// process input
CNodePtr cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
return InsertCastForInput(func_graph, cnode);
}
} // namespace opt
} // namespace mindspore
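
The deleted pass used the common "visited" idiom: DefinePattern matches any node not yet marked with kAttrVisited, and Process marks the node before rewriting it, so each real kernel node is handled exactly once. A standalone sketch of that idiom; Node, the visited set, and Process are stand-ins, not MindSpore's pattern engine:

#include <iostream>
#include <set>
#include <string>

struct Node {
  std::string name;
};

std::set<const Node *> visited;  // stands in for the per-node kAttrVisited attribute

const Node *Process(const Node *node) {
  if (node == nullptr || visited.count(node) > 0) {
    return nullptr;  // already handled; the pattern engine moves on
  }
  visited.insert(node);
  std::cout << "processing " << node->name << std::endl;
  return node;  // would return the rewritten node (e.g., with casts inserted)
}

int main() {
  Node n{"Conv2D"};
  Process(&n);  // processes the node
  Process(&n);  // second match is a no-op thanks to the visited mark
  return 0;
}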

mindspore/ccsrc/pre_activate/ascend/format_type/insert_cast_for_runop.h (+0 −35, file deleted)

@@ -1,35 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_FOR_RUNOP_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_FOR_RUNOP_H_
#include <string>

#include "pre_activate/common/optimizer.h"
#include "pre_activate/common/pattern_engine.h"
#include "ir/anf.h"
namespace mindspore {
namespace opt {
class RunOpInsertCast : public PatternProcessPass {
public:
explicit RunOpInsertCast(bool multigraph = true) : PatternProcessPass("insert_cast_for_runop", multigraph) {}
~RunOpInsertCast() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore

#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_INSERT_CAST_FOR_RUNOP_H_

mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.cc (+154 −0, new file)

@@ -0,0 +1,154 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h"

#include <vector>
#include <map>
#include <string>
#include <memory>

#include "session/anf_runtime_algorithm.h"
#include "kernel/kernel_build_info.h"
#include "utils/utils.h"
#include "kernel/common_utils.h"
#include "utils/context/ms_context.h"

namespace mindspore {
namespace opt {
const BaseRef RectifyDoMaskKernelInfo::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
VarPtr Xs = std::make_shared<SeqVar>();
return VectorRef({X, Xs});
}

const AnfNodePtr RectifyDoMaskKernelInfo::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &) const {
if (node == nullptr || !node->isa<CNode>()) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>();
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() == kPynativeMode) {
if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutDoMask->name()) {
return nullptr;
}
auto do_mask_input_format = AnfAlgo::GetInputFormat(node, 0);
if (do_mask_input_format != kOpFormat_DEFAULT) {
auto builder =
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
builder->SetInputFormat(kOpFormat_DEFAULT, 0);
builder->SetOutputFormat(kOpFormat_DEFAULT, 0);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
}
return nullptr;
}
if (AnfAlgo::GetCNodeName(cnode) != prim::kPrimDropoutGenMask->name()) {
return nullptr;
}
std::vector<CNodePtr> do_mask_node_list;
auto manager = graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto node_map = manager->node_users();
auto iter = node_map.find(node);
if (iter == node_map.end()) {
MS_LOG(EXCEPTION) << "Cannot find the node " << node->DebugString() << " in the graph manager!";
}
auto gen_mask_output_nodes = iter->second;
for (const auto &output_node : gen_mask_output_nodes) {
if (AnfAlgo::GetCNodeName(output_node.first) == prim::kPrimDropoutDoMask->name()) {
auto output_cnode = output_node.first->cast<CNodePtr>();
do_mask_node_list.push_back(output_cnode);
}
}
std::vector<size_t> input_shape;
for (const auto &output_node : do_mask_node_list) {
if (input_shape.empty()) {
input_shape = AnfAlgo::GetPrevNodeOutputInferShape(output_node, 0);
continue;
}
auto shape = AnfAlgo::GetPrevNodeOutputInferShape(output_node, 0);
if (!kernel::IsSameShape(shape, input_shape)) {
MS_LOG(EXCEPTION) << "All DropoutDoMask nodes connected to the same DropoutGenMask must have the same input"
<< " shape! GenMask: " << node->DebugString();
}
}
RectifyKernelInfo(do_mask_node_list);
return nullptr;
}

void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector<CNodePtr> &do_mask_node_list) const {
std::map<std::string, size_t> format_counter;
std::string special_format;
std::string convert_format;
for (const auto &do_mask : do_mask_node_list) {
auto do_mask_data_format = AnfAlgo::GetInputFormat(do_mask, 0);
if (special_format.empty() && kNeedTransFormatSet.find(do_mask_data_format) != kNeedTransFormatSet.end()) {
special_format = do_mask_data_format;
}
if (format_counter.find(do_mask_data_format) == format_counter.end()) {
format_counter[do_mask_data_format] = 1;
} else {
format_counter[do_mask_data_format] = format_counter[do_mask_data_format] + 1;
}
// If two or more different special formats appear, change every DoMask's format to the default format;
// this avoids inserting extra transdata nodes.
if (format_counter.size() > 2) {
convert_format = kOpFormat_DEFAULT;
break;
}
if (kNeedTransFormatSet.find(do_mask_data_format) != kNeedTransFormatSet.end() &&
special_format != do_mask_data_format) {
convert_format = kOpFormat_DEFAULT;
break;
}
}
if (format_counter.size() == 1) {
return;
}
if (convert_format.empty()) {
convert_format = GetConvertFormat(format_counter);
}
RectifyDropOutDoMaskKernelInfo(do_mask_node_list, convert_format);
}

std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map<std::string, size_t> &format_counter) const {
std::string convert_format;
size_t counter = 0;
// Majority vote over the formats seen on the DoMask nodes; track the best count seen so far.
for (const auto &iter : format_counter) {
if (counter < iter.second) {
convert_format = iter.first;
counter = iter.second;
} else if (counter == iter.second && kNeedTransFormatSet.find(convert_format) == kNeedTransFormatSet.end()) {
// On a tie, prefer the candidate over a non-special (default) current choice.
convert_format = iter.first;
}
}
return convert_format;
}

void RectifyDoMaskKernelInfo::RectifyDropOutDoMaskKernelInfo(const std::vector<CNodePtr> &do_mask_node_list,
const std::string &format) const {
for (const auto &do_mask : do_mask_node_list) {
auto builder =
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(do_mask));
builder->SetInputFormat(format, 0);
builder->SetOutputFormat(format, 0);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), do_mask.get());
}
}

} // namespace opt
} // namespace mindspore
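
GetConvertFormat above is a majority vote over the input formats of the sibling DoMask nodes. Below is a standalone, runnable version of that vote with the running-maximum bookkeeping written out; kOpFormat_DEFAULT and kNeedTransFormatSet are hypothetical stand-ins for the real constants:

#include <cstddef>
#include <iostream>
#include <map>
#include <set>
#include <string>

// Stand-ins for the real constants used by the pass.
const std::string kOpFormat_DEFAULT = "DefaultFormat";
const std::set<std::string> kNeedTransFormatSet = {"NC1HWC0", "FRACTAL_Z"};

// Pick the most frequent format; on a tie, prefer the candidate over a
// non-special current choice, mirroring the pass's tie-break.
std::string GetConvertFormat(const std::map<std::string, size_t> &format_counter) {
  std::string convert_format;
  size_t counter = 0;
  for (const auto &iter : format_counter) {
    if (counter < iter.second) {
      convert_format = iter.first;
      counter = iter.second;
    } else if (counter == iter.second && kNeedTransFormatSet.find(convert_format) == kNeedTransFormatSet.end()) {
      convert_format = iter.first;
    }
  }
  return convert_format;
}

int main() {
  // Three DoMask users saw NC1HWC0 and one saw the default format.
  std::map<std::string, size_t> counter = {{"NC1HWC0", 3}, {kOpFormat_DEFAULT, 1}};
  std::cout << GetConvertFormat(counter) << std::endl;  // prints NC1HWC0
  return 0;
}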

mindspore/ccsrc/pre_activate/ascend/format_type/rectify_do_mask_kernel_info.h (+41 −0, new file)

@@ -0,0 +1,41 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_RECTIFY_DO_MASK_KERNEL_INFO_H
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_RECTIFY_DO_MASK_KERNEL_INFO_H
#include <map>
#include <string>
#include <vector>

#include "pre_activate/common/optimizer.h"
namespace mindspore {
namespace opt {
class RectifyDoMaskKernelInfo : public PatternProcessPass {
public:
explicit RectifyDoMaskKernelInfo(bool multigraph = true)
: PatternProcessPass("rectify_do_mask_kernel_info", multigraph) {}
~RectifyDoMaskKernelInfo() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

private:
void RectifyKernelInfo(const std::vector<CNodePtr> &do_mask_node_list) const;
std::string GetConvertFormat(const std::map<std::string, size_t> &format_counter) const;
void RectifyDropOutDoMaskKernelInfo(const std::vector<CNodePtr> &do_mask_node_list, const std::string &format) const;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_RECTIFY_DO_MASK_KERNEL_INFO_H
