Browse Source

DepthwiseConv2d+Eltwise fusion pass

tags/v0.3.0-alpha
wangcong 6 years ago
parent
commit
0f42c66263
4 changed files with 48 additions and 2 deletions
  1. +8
    -2
      mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc
  2. +1
    -0
      mindspore/ccsrc/kernel/tbe/tbe_kernel_build.h
  3. +37
    -0
      mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc
  4. +2
    -0
      mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.h

+ 8
- 2
mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc View File

@@ -38,6 +38,11 @@ constexpr auto kFusionKernelNamePrfix = "te_fusion";
constexpr auto kOptional = "optional_";
constexpr auto kOpFormat_FRACTAL_Z = "FRACTAL_Z";

std::map<std::string, std::string> TbeKernelBuild::buffer_fussion_op_map_ = {
{"DepthwiseConv2dNative", "DepthwiseConv2D"},
{"TensorAdd", "Add"}
};

std::string NormalizeFullScopeName(const string &full_scope_name) {
// exp:Default/ReLU-op0 -->Default_ReLU_op0
string normal_ret = full_scope_name;
@@ -825,8 +830,9 @@ bool TbeKernelBuild::GenFusionComputeJson(const mindspore::AnfNodePtr &compute_n
(*compute_op_str)["output_desc"] = output_desc_list;
// gen others
auto type = AnfAlgo::GetCNodeName(cnode);
if (type == "TensorAdd") {
type = "Add";
// replace special op type for buffer fusion op
if (buffer_fussion_op_map_.find(type) != buffer_fussion_op_map_.end()) {
type = buffer_fussion_op_map_[type];
}
(*compute_op_str)["type"] = type;
tbe::TbeAdapter::NormalizeFuncName(&type);


+ 1
- 0
mindspore/ccsrc/kernel/tbe/tbe_kernel_build.h View File

@@ -76,6 +76,7 @@ class TbeKernelBuild {
std::map<const AnfNodePtr, FusionDataType> *spec_data_input);
static bool IsDynamicInput(const CNodePtr &cnode);
static size_t GetOptionalInput(const CNodePtr &cnode, bool is_dynamic_input);
static std::map<std::string, std::string> buffer_fussion_op_map_;
};

class TbeKernelJsonCreator {


+ 37
- 0
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc View File

@@ -545,6 +545,39 @@ void BufferFusion::MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr
}
}

void BufferFusion::MatchDepthwiseConvRelu(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
FusedNodeRecord *candidate_fusion, bool is_order) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(candidate_fusion);
auto manager = kernel_graph.manager();
MS_EXCEPTION_IF_NULL(manager);
if (is_order) {
// DepthwiseConvolution--->Elemwise
auto depthwise_conv = cnode->input(1);
MS_EXCEPTION_IF_NULL(depthwise_conv);
if (cnode->isa<CNode>() && AnfAlgo::GetCNodeName(depthwise_conv) == prim::kPrimDepthwiseConv2dNative->name()) {
std::vector<int> output_used_num{SizeToInt(manager->node_users()[depthwise_conv].size())};
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), depthwise_conv);
std::unordered_set<AnfNodePtr> record{cnode, depthwise_conv};
candidate_fusion->push_back(record);
SetRecordFusionId(record);
}
} else {
// Elemwise-->DepthwiseConvolution
auto relu = cnode->input(1);
MS_EXCEPTION_IF_NULL(relu);
if (cnode->isa<CNode>() && AnfAlgo::GetCNodeName(relu) == prim::kPrimRelu->name() ||
AnfAlgo::GetCNodeName() == kReluV2OpName) {
std::vector<int> output_used_num{SizeToInt(manager->node_users()[relu].size())};
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), relu);
std::unordered_set<AnfNodePtr> record{cnode, relu};
candidate_fusion->push_back(record);
SetRecordFusionId(record);
}
}
}

void BufferFusion::MatchOpNamePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) {
MS_EXCEPTION_IF_NULL(candidate_fusion);
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
@@ -563,7 +596,11 @@ void BufferFusion::MatchOpNamePattern(const session::KernelGraph &kernel_graph,
MatchBnupdateAddRelu(cnode, relu_input, kernel_graph, candidate_fusion);
} else if (relu_input->isa<CNode>() && AnfAlgo::GetCNodeName(relu_input) == prim::kPrimTupleGetItem->name()) {
MatchBnupdateRelu(cnode, relu_input, kernel_graph, candidate_fusion);
} else if (relu_input->isa<CNode>() && AnfAlgo::GetCNodeName(relu_input) == prim::kPrimDepthwiseConv2dNative->name()) {
MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, true);
}
} else if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimDepthwiseConv2dNative->name()) {
MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, false);
}
}
}


+ 2
- 0
mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.h View File

@@ -51,6 +51,8 @@ class BufferFusion : public Pass {
FusedNodeRecord *candidate_fusion);
void MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input,
const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion);
void MatchDepthwiseConvRelu(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
FusedNodeRecord *candidate_fusion, bool is_order);
void MatchOpNamePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion);
void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion);


Loading…
Cancel
Save