|
|
|
@@ -74,10 +74,21 @@ AnfNodePtr GetMul0(const FuncGraphPtr &graph, const AnfNodePtr &input2, const An |
|
|
|
} |
|
|
|
|
|
|
|
bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const AnfNodePtr &mul1_anf, |
|
|
|
const AnfNodePtr &reduce_sum) { |
|
|
|
const AnfNodePtr &reduce_sum, const AnfNodePtr &input2) { |
|
|
|
MS_EXCEPTION_IF_NULL(mul0_anf); |
|
|
|
MS_EXCEPTION_IF_NULL(mul1_anf); |
|
|
|
MS_EXCEPTION_IF_NULL(reduce_sum); |
|
|
|
MS_EXCEPTION_IF_NULL(input2); |
|
|
|
auto addn = input2->cast<CNodePtr>(); |
|
|
|
if (addn == nullptr || AnfAlgo::GetCNodeName(addn) != prim::kPrimAddN->name()) { |
|
|
|
MS_LOG(INFO) << "mul's second input is not addn"; |
|
|
|
return true; |
|
|
|
} |
|
|
|
std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(addn, 0); |
|
|
|
if (shape.size() != 2 || !(shape[1] == 1024 || shape[1] == 768)) { |
|
|
|
MS_LOG(INFO) << "Addn's infer shape is not equal [x,1024] or [x,768]"; |
|
|
|
return true; |
|
|
|
} |
|
|
|
if (!mul0_anf->isa<CNode>() || !mul1_anf->isa<CNode>()) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
@@ -86,11 +97,6 @@ bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const Anf |
|
|
|
auto mul0 = mul0_anf->cast<CNodePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(mul0); |
|
|
|
|
|
|
|
// when network is _VirtualDatasetCell, quit fusion |
|
|
|
if (mul0->fullname_with_scope().find("network-_VirtualDatasetCell") != std::string::npos) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
if (IsDepend(graph, mul0->input(1), reduce_sum)) { |
|
|
|
MS_LOG(INFO) << "mul0->input(1) depends on reduce_sum, quit fusion"; |
|
|
|
return true; |
|
|
|
@@ -128,7 +134,7 @@ const AnfNodePtr ConfusionMulGradFusion::Process(const FuncGraphPtr &graph, cons |
|
|
|
MS_LOG(INFO) << "Mul0 do not exist, quit fusion"; |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
if (QuitFusion(graph, mul0, mul1, node)) { |
|
|
|
if (QuitFusion(graph, mul0, mul1, node, input2)) { |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
|
|
|
|
|