|
|
@@ -566,8 +566,8 @@ void BufferFusion::MatchDepthwiseConvRelu(const CNodePtr &cnode, const session:: |
|
|
// Elemwise-->DepthwiseConvolution |
|
|
// Elemwise-->DepthwiseConvolution |
|
|
auto relu = cnode->input(1); |
|
|
auto relu = cnode->input(1); |
|
|
MS_EXCEPTION_IF_NULL(relu); |
|
|
MS_EXCEPTION_IF_NULL(relu); |
|
|
if (cnode->isa<CNode>() && AnfAlgo::GetCNodeName(relu) == prim::kPrimRelu->name() || |
|
|
|
|
|
AnfAlgo::GetCNodeName() == kReluV2OpName) { |
|
|
|
|
|
|
|
|
if (cnode->isa<CNode>() && |
|
|
|
|
|
(AnfAlgo::GetCNodeName(relu) == prim::kPrimRelu->name() || AnfAlgo::GetCNodeName(relu) == kReluV2OpName)) { |
|
|
std::vector<int> output_used_num{SizeToInt(manager->node_users()[relu].size())}; |
|
|
std::vector<int> output_used_num{SizeToInt(manager->node_users()[relu].size())}; |
|
|
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), relu); |
|
|
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), relu); |
|
|
std::unordered_set<AnfNodePtr> record{cnode, relu}; |
|
|
std::unordered_set<AnfNodePtr> record{cnode, relu}; |
|
|
|